agent: mypy type fixes

mypy examples/agent/__main__.py
mypy examples/agent/fastify.py
mypy examples/openai/__main__.py
Olivier Chafik 2024-04-10 19:45:13 +01:00 committed by ochafik
parent ea0c31b10b
commit 89dcc062a4
11 changed files with 109 additions and 98 deletions

examples/agent/__main__.py

@@ -3,7 +3,7 @@ import sys
 from time import sleep
 import typer
 from pydantic import BaseModel, Json, TypeAdapter
-from typing import Annotated, Callable, List, Union, Optional, Type
+from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
 import json, requests
 
 from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint
@@ -13,7 +13,7 @@ from examples.agent.utils import collect_functions, load_module
 from examples.openai.prompting import ToolsPromptStyle
 from examples.openai.subprocesses import spawn_subprocess
 
-def _get_params_schema(fn: Callable, verbose):
+def _get_params_schema(fn: Callable[[Any], Any], verbose):
     if isinstance(fn, OpenAPIMethod):
         return fn.parameters_schema
@@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose):
 def completion_with_tool_usage(
     *,
-    response_model: Optional[Union[Json, Type]]=None,
+    response_model: Optional[Union[Json[Any], type]]=None,
     max_iterations: Optional[int]=None,
-    tools: List[Callable],
+    tools: List[Callable[..., Any]],
     endpoint: str,
     messages: List[Message],
     auth: Optional[str],
@@ -56,7 +56,7 @@ def completion_with_tool_usage(
             type="function",
             function=ToolFunction(
                 name=fn.__name__,
-                description=fn.__doc__,
+                description=fn.__doc__ or '',
                 parameters=_get_params_schema(fn, verbose=verbose)
             )
         )
@@ -128,7 +128,7 @@ def completion_with_tool_usage(
 def main(
     goal: Annotated[str, typer.Option()],
     tools: Optional[List[str]] = None,
-    format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
+    format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
@@ -136,7 +136,7 @@ def main(
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     endpoint: Optional[str] = None,
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
@@ -187,8 +187,8 @@ def main(
         sleep(5)
 
     tool_functions = []
-    types = {}
-    for f in tools:
+    types: Dict[str, type] = {}
+    for f in (tools or []):
         if f.startswith('http://') or f.startswith('https://'):
             tool_functions.extend(openapi_methods_from_endpoint(f))
         else:
@@ -203,7 +203,7 @@ def main(
     if std_tools:
         tool_functions.extend(collect_functions(StandardTools))
 
-    response_model = None #str
+    response_model: Union[type, Json[Any]] = None #str
     if format:
         if format in types:
             response_model = types[format]
@@ -246,10 +246,7 @@ def main(
         seed=seed,
        n_probs=n_probs,
        min_keep=min_keep,
-        messages=[{
-            "role": "user",
-            "content": goal,
-        }]
+        messages=[Message(role="user", content=goal)],
     )
     print(result if response_model else f'➡️ {result}')
     # exit(0)
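
Note (illustrative, not part of the commit): the recurring fixes above are explicit Optional defaults and the Optional-ness of __doc__. A minimal sketch of what mypy rejects and accepts under no-implicit-optional semantics:

from typing import Optional

def bad(format: str = None): ...             # rejected: None default on a non-Optional parameter
def good(format: Optional[str] = None): ...  # accepted; narrow with "if format:" before use

# Likewise, fn.__doc__ is typed Optional[str], so fn.__doc__ or '' is needed
# where a plain str is expected (ToolFunction.description above).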

examples/agent/fastify.py

@@ -17,7 +17,7 @@ def bind_functions(app, module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
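
Note (illustrative, not part of the commit): typing.Type is the annotation-side spelling, while the builtin type is the runtime class object that isinstance() should check against:

from typing import Type

def is_a_class(v: object) -> bool:
    return isinstance(v, type)  # True for classes, False for instances

def instantiate(cls: Type[Exception]) -> Exception:
    return cls()                # Type[...] stays in the annotation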

examples/agent/openapi_client.py

@@ -17,28 +17,29 @@ class OpenAPIMethod:
         request_body = post_descriptor.get('requestBody')
 
         self.parameters = {p['name']: p for p in parameters}
-        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})'
+        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'
 
         self.body = None
-        self.body_name = None
         if request_body:
-            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})'
+            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'
+
+            body_name = 'body'
+            i = 2
+            while body_name in self.parameters:
+                body_name = f'body{i}'
+                i += 1
+
             self.body = dict(
+                name=body_name,
                 required=request_body['required'],
                 schema=request_body['content']['application/json']['schema'],
             )
-            self.body_name = 'body'
-            i = 2
-            while self.body_name in self.parameters:
-                self.body_name = f'body{i}'
-                i += 1
 
         self.parameters_schema = dict(
             type='object',
             properties={
                 **({
-                    self.body_name: self.body['schema']
+                    self.body['name']: self.body['schema']
                 } if self.body else {}),
                 **{
                     name: param['schema']
@@ -46,14 +47,14 @@ class OpenAPIMethod:
                 }
             },
             components=catalog.get('components'),
-            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else [])
+            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
         )
 
     def __call__(self, **kwargs):
         if self.body:
-            body = kwargs.pop(self.body_name, None)
+            body = kwargs.pop(self.body['name'], None)
             if self.body['required']:
-                assert body is not None, f'Missing required body parameter: {self.body_name}'
+                assert body is not None, f'Missing required body parameter: {self.body["name"]}'
             else:
                 body = None
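
Note (illustrative, not part of the commit): the old code kept a separate self.body_name: Optional[str] next to self.body, and mypy cannot see that the two are None together; folding the name into the body dict leaves a single Optional guarded by one check. A sketch of the pattern:

from typing import Any, Dict, Optional

body: Optional[Dict[str, Any]] = dict(name='body', required=True, schema={})

properties: Dict[str, Any] = {
    **({body['name']: body['schema']} if body else {}),  # one Optional, one guard
}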

examples/agent/tools/std_tools.py

@@ -15,6 +15,9 @@ class Duration(BaseModel):
     months: Optional[int] = None
     years: Optional[int] = None
 
+    def __str__(self) -> str:
+        return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
+
     @property
     def get_total_seconds(self) -> int:
         return sum([
@@ -29,6 +32,10 @@ class Duration(BaseModel):
 class WaitForDuration(BaseModel):
     duration: Duration
 
+    def __call__(self):
+        sys.stderr.write(f"Waiting for {self.duration}...\n")
+        time.sleep(self.duration.get_total_seconds)
+
 class WaitForDate(BaseModel):
     until: date
 
@@ -43,7 +50,7 @@ class WaitForDate(BaseModel):
         days, seconds = time_diff.days, time_diff.seconds
 
-        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
+        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
         time.sleep(days * 86400 + seconds)
         sys.stderr.write(f"Reached the target date: {self.until}\n")
@@ -67,8 +74,8 @@ class StandardTools:
         return _for()
 
     @staticmethod
-    def say_out_loud(something: str) -> str:
+    def say_out_loud(something: str) -> None:
         """
         Just says something. Used to say each thought out loud
         """
-        return subprocess.check_call(["say", something])
+        subprocess.check_call(["say", something])
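
Note (illustrative, not part of the commit): subprocess.check_call() is typed as returning int (0 on success, raising CalledProcessError otherwise), so returning its result from a function annotated -> str fails type checking; annotating -> None and dropping the return is the minimal fix:

import subprocess

def say(something: str) -> None:
    subprocess.check_call(["say", something])  # int result intentionally discarded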

examples/agent/utils.py

@@ -9,8 +9,10 @@ def load_source_as_module(source):
         i += 1
 
     spec = importlib.util.spec_from_file_location(module_name, source)
+    assert spec, f'Failed to load {source} as module'
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
+    assert spec.loader, f'{source} spec has no loader'
     spec.loader.exec_module(module)
     return module
@@ -29,7 +31,7 @@ def collect_functions(module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
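
Note (illustrative, not part of the commit): in typeshed, importlib.util.spec_from_file_location() returns Optional[ModuleSpec] and spec.loader is Optional too; the added asserts narrow both before use. Minimal sketch, assuming a module file at /tmp/my_mod.py:

import importlib.util

spec = importlib.util.spec_from_file_location("my_mod", "/tmp/my_mod.py")
assert spec, "failed to build a module spec"  # Optional[ModuleSpec] -> ModuleSpec
module = importlib.util.module_from_spec(spec)
assert spec.loader, "spec has no loader"      # Optional[Loader] -> Loader
spec.loader.exec_module(module)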

examples/json_schema_to_grammar.py

@@ -55,9 +55,9 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item
 
 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: List[str]):
         self.content = content
-        self.deps = deps or []
+        self.deps = deps
 
 _up_to_15_digits = _build_repetition('[0-9]', 0, 15)
@@ -118,7 +118,7 @@ class SchemaConverter:
     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
         )
         return f'"{escaped}"'
@@ -157,13 +157,13 @@ class SchemaConverter:
         self._rules[key] = rule
         return key
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Any, url: str):
         '''
         Resolves all $ref fields in the given schema, fetching any remote schemas,
         replacing $ref with absolute reference URL and populating self._refs with the
         respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Any):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -223,7 +223,7 @@ class SchemaConverter:
         assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
         pattern = pattern[1:-1]
-        sub_rule_ids = {}
+        sub_rule_ids: Dict[str, str] = {}
 
         i = 0
         length = len(pattern)
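
Note (illustrative, not part of the commit): the _format_literal change works because dict.get() is typed as returning an Optional while re.sub() requires its callback to return str; indexing with [] tells mypy the key is always present, which holds here since the regex only matches escapable characters. Sketch with assumed escape tables:

import re

ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}
ESCAPE_RE = re.compile(r'[\r\n"]')

def format_literal(literal: str) -> str:
    return ESCAPE_RE.sub(lambda m: ESCAPES[m.group(0)], literal)  # str, never Optional[str]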

examples/openai/api.py

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
@@ -16,7 +16,7 @@ class Message(BaseModel):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
     content: Optional[str]
-    tool_calls: Optional[list[ToolCall]] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 class ToolFunction(BaseModel):
     name: str
@@ -29,7 +29,7 @@ class Tool(BaseModel):
 class ResponseFormat(BaseModel):
     type: Literal["json_object"]
-    schema: Optional[Dict] = None
+    schema: Optional[Json[Any]] = None # type: ignore
 
 class LlamaCppParams(BaseModel):
     n_predict: Optional[int] = None
@@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel):
 class ChatCompletionRequest(LlamaCppParams):
     model: str
-    tools: Optional[list[Tool]] = None
-    messages: list[Message] = None
+    tools: Optional[List[Tool]] = None
+    messages: Optional[List[Message]] = None
     prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
 class Choice(BaseModel):
     index: int
     message: Message
-    logprobs: Optional[Json] = None
+    logprobs: Optional[Json[Any]] = None
     finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
 
 class Usage(BaseModel):
@@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"]
     created: int
     model: str
-    choices: list[Choice]
+    choices: List[Choice]
     usage: Usage
     system_fingerprint: str
     error: Optional[CompletionError] = None
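
Note (illustrative, not part of the commit): the list[...] -> List[...] changes keep these models importable on Python < 3.9 (PEP 585 built-in generics), and pydantic's Json is parameterized as Json[Any] so the checker sees a concrete inner type:

from typing import Any, List, Optional
from pydantic import BaseModel, Json

class Demo(BaseModel):
    choices: List[int] = []               # list[int] would need Python >= 3.9
    logprobs: Optional[Json[Any]] = None  # parameterized for the type checker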

examples/openai/llama_cpp_server_api.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 from pydantic import Json
 
 from examples.openai.api import LlamaCppParams
@@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams):
     cache_prompt: Optional[bool] = None
     grammar: Optional[str] = None
-    json_schema: Optional[Json] = None
+    json_schema: Optional[Json[Any]] = None

examples/openai/prompting.py

@@ -6,12 +6,12 @@ from pathlib import Path
 import random
 import re
 import sys
-from typing import Annotated, Optional
-from pydantic import BaseModel, Field
+from typing import Annotated, Any, Optional
+from pydantic import BaseModel, Field, Json
 
 from examples.json_schema_to_grammar import SchemaConverter
 from examples.openai.api import Tool, Message, FunctionCall, ToolCall
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
 # _THOUGHT_KEY = "thought"
@@ -65,7 +65,7 @@ class ChatTemplate(BaseModel):
     @property
     def potentially_supports_parallel_calls(self) -> bool:
-        return self.formats_tool_result and self.formats_tool_name
+        return bool(self.formats_tool_result and self.formats_tool_name)
 
     def __init__(self, template: str, eos_token: str, bos_token: str):
         super().__init__(template=template, eos_token=eos_token, bos_token=bos_token)
@@ -161,7 +161,7 @@ class ChatTemplate(BaseModel):
     @staticmethod
     def from_huggingface(model_id: str):
-        from transformers import LlamaTokenizer
+        from transformers import LlamaTokenizer # type: ignore
         tokenizer = LlamaTokenizer.from_pretrained(model_id)
         return ChatTemplate(
             template = tokenizer.chat_template or tokenizer.default_chat_template,
@@ -170,7 +170,7 @@ class ChatTemplate(BaseModel):
     def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
         result = self._template.render(
-            messages=messages,
+            messages=[messages.model_dump() for messages in messages],
             eos_token=self.eos_token,
             bos_token='' if omit_bos else self.bos_token,
             raise_exception=raise_exception,
@@ -180,7 +180,7 @@ class ChatTemplate(BaseModel):
 class ChatHandlerArgs(BaseModel):
     chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
+    response_schema: Optional[Json[Any]] = None
     tools: Optional[list[Tool]] = None
 
 class ChatHandler(ABC):
@@ -199,9 +199,9 @@ class ChatHandler(ABC):
         assert system_prompt.role == "system"
         # TODO: add to last system message, or create a new one just before the last user message
         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-        if system_message is not None:
+        if system_message:
             (i, m) = system_message
-            return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:]
+            return messages[:i] + [Message(role="system", content=(system_prompt.content + '\n' if system_prompt.content else '') + (m.content or ''))] + messages[i+1:]
         else:
             return [system_prompt] + messages
@@ -282,7 +282,7 @@
         if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
-            current_content = []
+            current_content: list[str] = []
 
             def flush():
                 nonlocal current_content
@@ -311,24 +311,24 @@
             messages = new_messages
 
         # JSON!
-        messages = [m.model_dump() for m in messages]
+        # messages = [m.model_dump() for m in messages]
 
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
-                {
-                    **m,
+                Message(**{
+                    **m.model_dump(),
                     "tool_calls": [
-                        {
-                            **tc,
+                        ToolCall(**{
+                            **tc.model_dump(),
                             "function": {
-                                "name": tc["function"]["name"],
-                                "arguments": json.dumps(tc["function"]["arguments"]),
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
                             }
-                        }
-                        for tc in m["tool_calls"]
-                    ] if m.get("tool_calls") else None
-                }
+                        })
+                        for tc in m.tool_calls
+                    ] if m.tool_calls else None
+                })
                 for m in messages
             ]
@@ -364,7 +364,7 @@ class ToolCallTagsChatHandler(ChatHandler):
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
 
         tool_rules = []
-        for tool in self.args.tools:
+        for tool in self.args.tools or []:
 
             parameters_schema = tool.function.parameters
             parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
@@ -416,7 +416,7 @@
         if len(parts) == 1:
             return Message(role="assistant", content=s)
         else:
-            content = []
+            content: list[str] = []
             tool_calls = []
             for i, part in enumerate(parts):
                 if i % 2 == 0:
@@ -431,8 +431,8 @@
                             id=gen_callid(),
                             function=FunctionCall(**fc)))
 
-            content = '\n'.join(content).strip()
-            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+            content_str = '\n'.join(content).strip()
+            return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
@@ -444,7 +444,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
             role="system",
             content=template.replace(
                 '{tools}',
-                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in (self.args.tools or [])),
             )
         )
@@ -456,11 +456,11 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
         path = str(Path(__file__).parent / "hermes_function_calling")
         if path not in sys.path: sys.path.insert(0, path)
         try:
-            from examples.openai.hermes_function_calling.prompter import PromptManager
+            from examples.openai.hermes_function_calling.prompter import PromptManager # type: ignore
         except ImportError:
             raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
 
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools])
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools or []])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
         self.output_format_prompt = Message(**prompt[0])
@@ -471,7 +471,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         self.output_format_prompt = Message(
             role="system",
             content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(args.tools)
+                _tools_typescript_signatures(args.tools or [])
         )
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -481,7 +481,7 @@
                 converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                 converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                 converter._format_literal('\n'))
-            for i, tool in enumerate(self.args.tools)
+            for i, tool in enumerate(self.args.tools or [])
         ]
 
         not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
@@ -583,7 +583,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
         response_schema = converter.resolve_refs(args.response_schema or {"type": "string"}, 'response')
         tool_parameter_schemas = {
             tool.function.name: converter.resolve_refs(tool.function.parameters, tool.function.name)
-            for tool in self.args.tools
+            for tool in self.args.tools or []
         }
         # sys.stderr.write(f"# RESOLVED RESPONSE SCHEMA: {json.dumps(response_schema, indent=2)}\n")
         # sys.stderr.write(f"# RESOLVED TOOL PARAMETER SCHEMA: {json.dumps(tool_parameter_schemas, indent=2)}\n")
@@ -614,7 +614,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
             content='\n'.join([
                 'You are a function calling AI model.',
                 'Here are the tools available:',
-                _tools_schema_signatures(self.args.tools, indent=2),
+                _tools_schema_signatures(self.args.tools or [], indent=2),
                 # _tools_typescript_signatures(self.args.tools),
                 _please_respond_with_schema(
                     _make_bespoke_schema(
@@ -716,10 +716,10 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
     elif tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls)
     else:
-        raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+        raise ValueError(f"Unsupported tool call style: {tool_style}")
 
 # os.environ.get('NO_TS')
-def _please_respond_with_schema(schema: dict) -> str:
+def _please_respond_with_schema(schema: Json[Any]) -> str:
     sig = json.dumps(schema, indent=2)
     # _ts_converter = SchemaToTypeScriptConverter()
     # # _ts_converter.resolve_refs(schema, 'schema')
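
Note (illustrative, not part of the commit): most of this file's churn is one pattern: self.args.tools is Optional, and iterating an Optional list fails mypy, so "tools or []" substitutes an empty list in the None case:

from typing import List, Optional

def tool_names(tools: Optional[List[str]]) -> List[str]:
    return [t for t in tools or []]  # iterates nothing when tools is None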

examples/openai/__main__.py

@@ -3,7 +3,7 @@ from pathlib import Path
 import time
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.api import ChatCompletionResponse, Choice, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
@@ -21,12 +21,12 @@ def generate_id(prefix):
     return f"{prefix}{random.randint(0, 1 << 32)}"
 
 def main(
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     template_hf_model_id_fallback: Annotated[Optional[str], typer.Option(help="If the GGUF model does not contain a chat template, get it from this HuggingFace tokenizer")] = 'meta-llama/Llama-2-7b-chat-hf',
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = False,
+    parallel_calls: bool = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -39,10 +39,11 @@ def main(
     if endpoint:
         sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+        assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint"
         chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
     else:
-        metadata = GGUFKeyValues(model)
+        metadata = GGUFKeyValues(Path(model))
 
         if not context_length:
             context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
@@ -51,6 +52,7 @@ def main(
             chat_template = ChatTemplate.from_gguf(metadata)
         else:
             sys.stderr.write(f"# WARNING: Model does not contain a chat template, fetching it from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+            assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when the model does not contain a chat template"
             chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
         if verbose:
@@ -93,9 +95,8 @@ def main(
             verbose=verbose,
         )
 
-        messages = chat_request.messages
-        prompt = chat_handler.render_prompt(messages)
+        prompt = chat_handler.render_prompt(chat_request.messages) if chat_request.messages else chat_request.prompt
+        assert prompt is not None, "One of prompt or messages field is required"
 
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
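
Note (illustrative, not part of the commit): a parameter annotated Optional[Path] cannot carry a str default, so model is declared as str and converted with Path(model) only where a filesystem path is actually needed:

from pathlib import Path

def main(model: str = "models/7B/ggml-model-f16.gguf") -> None:
    metadata_file = Path(model)  # convert at the point of use
    print(metadata_file.name)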

examples/openai/ts_converter.py

@@ -1,6 +1,8 @@
-from typing import Any, List, Set, Tuple, Union
+from typing import Any, Dict, List, Set, Tuple, Union
 import json
 
+from pydantic import Json
+
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
     #   // Get the price of a particular car model
@@ -15,17 +17,18 @@ class SchemaToTypeScriptConverter:
     #   location: string,
     # }) => any;
 
-    def __init__(self):
-        self._refs = {}
-        self._refs_being_resolved = set()
+    def __init__(self, allow_fetch: bool = True):
+        self._refs: Dict[str, Json[Any]] = {}
+        self._refs_being_resolved: Set[str] = set()
+        self._allow_fetch = allow_fetch
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Json[Any], url: str):
         '''
        Resolves all $ref fields in the given schema, fetching any remote schemas,
        replacing $ref with absolute reference URL and populating self._refs with the
        respective referenced (sub)schema dictionaries.
        '''
-        def visit(n: dict):
+        def visit(n: Json[Any]):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -64,7 +67,7 @@ class SchemaToTypeScriptConverter:
             return n
         return visit(schema)
 
-    def _desc_comment(self, schema: dict):
+    def _desc_comment(self, schema: Json[Any]):
         desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
         return f'// {desc}\n' if desc else ''
@@ -78,11 +81,11 @@ class SchemaToTypeScriptConverter:
             f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
         ] + (
-            [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
+            [f"{self._desc_comment(additional_properties) if isinstance(additional_properties, dict) else ''}[key: string]: {self.visit(additional_properties)}"]
             if additional_properties is not None else []
         )) + "\n}"
 
-    def visit(self, schema: dict):
+    def visit(self, schema: Json[Any]):
         def print_constant(v):
             return json.dumps(v)
 
@@ -90,7 +93,7 @@ class SchemaToTypeScriptConverter:
         schema_format = schema.get('format')
 
         if 'oneOf' in schema or 'anyOf' in schema:
-            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf') or [])
         elif isinstance(schema_type, list):
             return '|'.join(self.visit({'type': t}) for t in schema_type)
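
Note (illustrative, not part of the commit): empty-container initializers like {} and set() give mypy nothing to infer an element type from (it reports "need type annotation"), hence the explicit Dict/Set annotations in __init__:

from typing import Any, Dict, Set

class RefTracker:
    def __init__(self) -> None:
        self._refs: Dict[str, Any] = {}         # bare {} has no inferable value type
        self._being_resolved: Set[str] = set()  # likewise for set()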