Update pydantic-models-to-grammar.py

This commit is contained in:
Maximilian Winter 2024-01-12 19:26:28 +01:00
parent 3507238cca
commit 149d00ffd0

View file

@ -256,7 +256,7 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None
return float_rule, additional_rules
def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_for_triple_quoted_string, field_name,
def generate_gbnf_rule_for_type(model_name, field_name,
field_type, is_optional, processed_models, created_rules,
field_info=None) -> \
Tuple[str, list]:
@ -264,8 +264,7 @@ def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_f
Generate GBNF rule for a given field type.
:param model_name: Name of the model.
:param look_for_markdown_code_block: Look for Markdown code block
:param look_for_triple_quoted_string
:param field_name: Name of the field.
:param field_type: Type of the field.
:param is_optional: Whether the field is optional.
@ -293,8 +292,7 @@ def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_f
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block,
look_for_triple_quoted_string,
element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
f"{field_name}-element",
element_type, is_optional, processed_models,
created_rules)
@ -305,8 +303,7 @@ def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_f
elif get_origin(field_type) == set: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block,
look_for_triple_quoted_string,
element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
f"{field_name}-element",
element_type, is_optional, processed_models,
created_rules)
@ -322,14 +319,10 @@ def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_f
key_type, value_type = get_args(field_type)
additional_key_type, additional_key_rules = generate_gbnf_rule_for_type(model_name,
look_for_markdown_code_block,
look_for_triple_quoted_string,
f"{field_name}-key-type",
key_type, is_optional, processed_models,
created_rules)
additional_value_type, additional_value_rules = generate_gbnf_rule_for_type(model_name,
look_for_markdown_code_block,
look_for_triple_quoted_string,
f"{field_name}-value-type",
value_type, is_optional,
processed_models, created_rules)
@ -344,8 +337,6 @@ def generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block, look_f
for union_type in union_types:
if not issubclass(union_type, NoneType):
union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name,
look_for_markdown_code_block,
look_for_triple_quoted_string,
field_name, union_type,
False,
processed_models, created_rules)
@ -469,8 +460,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
field_type = field_info
field_info = model.model_fields[field_name]
is_optional = field_info.is_required is False and get_origin(field_type) is Optional
rule_name, additional_rules = generate_gbnf_rule_for_type(model_name, look_for_markdown_code_block,
look_for_triple_quoted_string,
rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
format_model_and_field_name(field_name),
field_type, is_optional,
processed_models, created_rules, field_info)