Fixed some issues and bugs in the grammar generator. Improved documentation.

This commit is contained in:
Maximilian Winter 2024-01-13 05:52:30 +01:00
parent 5f719de77c
commit 0fd29f8929

View file

@ -27,7 +27,7 @@ class PydanticDataType(Enum):
"""
STRING = "string"
TRIPLE_QUOTED_STRING = "triple_quoted_string"
MARKDOWN_STRING = "markdown_string"
MARKDOWN_CODE_BLOCK = "markdown_code_block"
BOOLEAN = "boolean"
INTEGER = "integer"
FLOAT = "float"
@ -282,7 +282,7 @@ def generate_gbnf_rule_for_type(model_name, field_name,
if isclass(field_type) and issubclass(field_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules = generate_gbnf_grammar(field_type, processed_models, created_rules)
nested_model_rules,_, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules
elif isclass(field_type) and issubclass(field_type, Enum):
@ -290,7 +290,7 @@ def generate_gbnf_rule_for_type(model_name, field_name,
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list or field_type == list: # Array
elif get_origin(field_type) == list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
f"{field_name}-element",
@ -343,7 +343,6 @@ def generate_gbnf_rule_for_type(model_name, field_name,
union_rules.append(union_gbnf_type)
rules.extend(union_rules_list)
elif not issubclass(union_type, NoneType):
union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name,
field_name, union_type,
@ -366,10 +365,10 @@ def generate_gbnf_rule_for_type(model_name, field_name,
if field_info and hasattr(field_info, 'json_schema_extra') and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get('triple_quoted_string', False)
markdown_string = field_info.json_schema_extra.get('markdown_string', False)
markdown_string = field_info.json_schema_extra.get('markdown_code_block', False)
gbnf_type = PydanticDataType.TRIPLE_QUOTED_STRING.value if triple_quoted_string else PydanticDataType.STRING.value
gbnf_type = PydanticDataType.MARKDOWN_STRING.value if markdown_string else gbnf_type
gbnf_type = PydanticDataType.MARKDOWN_CODE_BLOCK.value if markdown_string else gbnf_type
elif field_info and hasattr(field_info, 'pattern'):
# Convert regex pattern to grammar rule
@ -473,7 +472,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
format_model_and_field_name(field_name),
field_type, is_optional,
processed_models, created_rules, field_info)
look_for_markdown_code_block = True if rule_name == "markdown_string" else False
look_for_markdown_code_block = True if rule_name == "markdown_code_block" else False
look_for_triple_quoted_string = True if rule_name == "triple_quoted_string" else False
if not look_for_markdown_code_block and not look_for_triple_quoted_string:
if rule_name not in created_rules:
@ -481,8 +480,8 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
model_rule_parts.append(f' ws \"\\\"{field_name}\\\"\" ": " {rule_name}') # Adding escaped quotes
nested_rules.extend(additional_rules)
else:
has_triple_quoted_string = look_for_markdown_code_block
has_markdown_code_block = look_for_triple_quoted_string
has_triple_quoted_string = look_for_triple_quoted_string
has_markdown_code_block = look_for_markdown_code_block
fields_joined = r' "," "\n" '.join(model_rule_parts)
model_rule = fr'{model_name} ::= "{{" "\n" {fields_joined} "\n" ws "}}"'
@ -507,7 +506,7 @@ def generate_gbnf_grammar_from_pydantic_models(models: List[Type[BaseModel]], ou
This method takes a list of Pydantic models and uses them to generate a GBNF grammar string. The generated grammar string can be used for parsing and validating data using the generated
grammar.
Parameters:
Args:
models (List[Type[BaseModel]]): A list of Pydantic models to generate the grammar from.
outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
@ -534,20 +533,20 @@ def generate_gbnf_grammar_from_pydantic_models(models: List[Type[BaseModel]], ou
all_rules.extend(model_rules)
if list_of_outputs:
root_rule = r'root ::= ws "[" grammar-models ("," grammar-models)* "]"' + "\n"
root_rule = r'root ::= (" "| "\n") "[" grammar-models ("," grammar-models)* "]"' + "\n"
else:
root_rule = r'root ::= ws grammar-models' + "\n"
root_rule = r'root ::= (" "| "\n") grammar-models' + "\n"
root_rule += "grammar-models ::= " + " | ".join(
[format_model_and_field_name(model.__name__) for model in models])
all_rules.insert(0, root_rule)
return "\n".join(all_rules)
elif outer_object_name is not None:
if list_of_outputs:
root_rule = fr'root ::= ws "[" {format_model_and_field_name(outer_object_name)} ("," {format_model_and_field_name(outer_object_name)})* "]"' + "\n"
root_rule = fr'root ::= (" "| "\n") "[" {format_model_and_field_name(outer_object_name)} ("," {format_model_and_field_name(outer_object_name)})* "]"' + "\n"
else:
root_rule = f"root ::= {format_model_and_field_name(outer_object_name)}\n"
model_rule = fr'{format_model_and_field_name(outer_object_name)} ::= ws "{{" ws "\"{outer_object_name}\"" ": " grammar-models'
model_rule = fr'{format_model_and_field_name(outer_object_name)} ::= (" "| "\n") "{{" ws "\"{outer_object_name}\"" ": " grammar-models'
fields_joined = " | ".join(
[fr'{format_model_and_field_name(model.__name__)}-grammar-model' for model in models])
@ -1032,11 +1031,11 @@ def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
Add a 'run' method to a dynamic Pydantic model, using the provided function.
Args:
- model (Type[BaseModel]): Dynamic Pydantic model class.
- func (Callable): Function to be added as a 'run' method to the model.
model (Type[BaseModel]): Dynamic Pydantic model class.
func (Callable): Function to be added as a 'run' method to the model.
Returns:
- Type[BaseModel]: Pydantic model class with the added 'run' method.
Type[BaseModel]: Pydantic model class with the added 'run' method.
"""
def run_method_wrapper(self):
@ -1054,15 +1053,15 @@ def create_dynamic_models_from_dictionaries(dictionaries: List[dict]):
Create a list of dynamic Pydantic model classes from a list of dictionaries.
Args:
- dictionaries (List[dict]): List of dictionaries representing model structures.
dictionaries (List[dict]): List of dictionaries representing model structures.
Returns:
- List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
"""
dynamic_models = []
for func in dictionaries:
model_name = format_model_and_field_name(func.get("name", ""))
dyn_model = convert_dictionary_to_to_pydantic_model(func, model_name)
dyn_model = convert_dictionary_to_pydantic_model(func, model_name)
dynamic_models.append(dyn_model)
return dynamic_models
@ -1094,40 +1093,45 @@ def list_to_enum(enum_name, values):
return Enum(enum_name, {value: value for value in values})
def convert_dictionary_to_to_pydantic_model(dictionary: dict, model_name: str = 'CustomModel') -> Type[BaseModel]:
def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = 'CustomModel') -> Type[BaseModel]:
"""
Convert a dictionary to a Pydantic model class.
Args:
- dictionary (dict): Dictionary representing the model structure.
- model_name (str): Name of the generated Pydantic model.
dictionary (dict): Dictionary representing the model structure.
model_name (str): Name of the generated Pydantic model.
Returns:
- Type[BaseModel]: Generated Pydantic model class.
Type[BaseModel]: Generated Pydantic model class.
"""
fields = {}
if "properties" in dictionary:
for field_name, field_data in dictionary.get("properties", {}).items():
if field_data == 'object':
submodel = convert_dictionary_to_to_pydantic_model(dictionary, f'{model_name}_{field_name}')
submodel = convert_dictionary_to_pydantic_model(dictionary, f'{model_name}_{field_name}')
fields[field_name] = (submodel, ...)
else:
field_type = field_data.get('type', 'str')
if field_data.get("enum", []):
fields[field_name] = (list_to_enum(field_name, field_data.get("enum", [])), ...)
if field_type == "array":
elif field_type == "array":
items = field_data.get("items", {})
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_to_pydantic_model(array, f'{model_name}_{field_name}_items')
array_type = convert_dictionary_to_pydantic_model(array, f'{model_name}_{field_name}_items')
fields[field_name] = (List[array_type], ...)
else:
fields[field_name] = (list, ...)
elif field_type == 'object':
submodel = convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}_{field_name}')
submodel = convert_dictionary_to_pydantic_model(field_data, f'{model_name}_{field_name}')
fields[field_name] = (submodel, ...)
elif field_type == 'required':
required = field_data.get("enum", [])
for key, field in fields.items():
if key not in required:
fields[key] = (Optional[fields[key][0]], ...)
else:
field_type = json_schema_to_python_types(field_type)
fields[field_name] = (field_type, ...)
@ -1139,13 +1143,15 @@ def convert_dictionary_to_to_pydantic_model(dictionary: dict, model_name: str =
elif field_name == "description":
fields["__doc__"] = field_data
elif field_name == "parameters":
return convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}')
return convert_dictionary_to_pydantic_model(field_data, f'{model_name}')
if "parameters" in dictionary:
field_data = {"function": dictionary}
return convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}')
return convert_dictionary_to_pydantic_model(field_data, f'{model_name}')
if 'required' in dictionary:
required = dictionary.get('required', [])
for key, field in fields.items():
if key not in required:
fields[key] = (Optional[fields[key][0]], ...)
custom_model = create_model(model_name, **fields)
return custom_model