diff --git a/examples/agent/agent.py b/examples/agent/agent.py index 1cbc254fe..a283e0628 100644 --- a/examples/agent/agent.py +++ b/examples/agent/agent.py @@ -3,7 +3,7 @@ import sys from time import sleep import typer from pydantic import BaseModel, Json, TypeAdapter -from typing import Annotated, Callable, List, Union, Optional, Type +from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type import json, requests from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint @@ -13,7 +13,7 @@ from examples.agent.utils import collect_functions, load_module from examples.openai.prompting import ToolsPromptStyle from examples.openai.subprocesses import spawn_subprocess -def _get_params_schema(fn: Callable, verbose): +def _get_params_schema(fn: Callable[[Any], Any], verbose): if isinstance(fn, OpenAPIMethod): return fn.parameters_schema @@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose): def completion_with_tool_usage( *, - response_model: Optional[Union[Json, Type]]=None, + response_model: Optional[Union[Json[Any], type]]=None, max_iterations: Optional[int]=None, - tools: List[Callable], + tools: List[Callable[..., Any]], endpoint: str, messages: List[Message], auth: Optional[str], @@ -56,7 +56,7 @@ def completion_with_tool_usage( type="function", function=ToolFunction( name=fn.__name__, - description=fn.__doc__, + description=fn.__doc__ or '', parameters=_get_params_schema(fn, verbose=verbose) ) ) @@ -128,7 +128,7 @@ def completion_with_tool_usage( def main( goal: Annotated[str, typer.Option()], tools: Optional[List[str]] = None, - format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None, + format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None, max_iterations: Optional[int] = 10, std_tools: Optional[bool] = False, auth: Optional[str] = None, @@ -136,7 +136,7 @@ def main( verbose: bool = False, style: Optional[ToolsPromptStyle] = None, - model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", + model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", endpoint: Optional[str] = None, context_length: Optional[int] = None, # endpoint: str = 'http://localhost:8080/v1/chat/completions', @@ -187,8 +187,8 @@ def main( sleep(5) tool_functions = [] - types = {} - for f in tools: + types: Dict[str, type] = {} + for f in (tools or []): if f.startswith('http://') or f.startswith('https://'): tool_functions.extend(openapi_methods_from_endpoint(f)) else: @@ -203,7 +203,7 @@ def main( if std_tools: tool_functions.extend(collect_functions(StandardTools)) - response_model = None #str + response_model: Union[type, Json[Any]] = None #str if format: if format in types: response_model = types[format] @@ -246,10 +246,7 @@ def main( seed=seed, n_probs=n_probs, min_keep=min_keep, - messages=[{ - "role": "user", - "content": goal, - }] + messages=[Message(role="user", content=goal)], ) print(result if response_model else f'➡️ {result}') # exit(0) diff --git a/examples/agent/fastify.py b/examples/agent/fastify.py index 02d475b40..cf02ccc31 100644 --- a/examples/agent/fastify.py +++ b/examples/agent/fastify.py @@ -17,7 +17,7 @@ def bind_functions(app, module): if k == k.capitalize(): continue v = getattr(module, k) - if not callable(v) or isinstance(v, Type): + if not callable(v) or isinstance(v, type): continue if not hasattr(v, '__annotations__'): continue diff --git a/examples/agent/openapi_client.py b/examples/agent/openapi_client.py index 0a6980b73..d336c7436 100644 --- a/examples/agent/openapi_client.py +++ b/examples/agent/openapi_client.py @@ -17,28 +17,29 @@ class OpenAPIMethod: request_body = post_descriptor.get('requestBody') self.parameters = {p['name']: p for p in parameters} - assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})' + assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})' self.body = None - self.body_name = None if request_body: - assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})' + assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})' + + body_name = 'body' + i = 2 + while body_name in self.parameters: + body_name = f'body{i}' + i += 1 + self.body = dict( + name=body_name, required=request_body['required'], schema=request_body['content']['application/json']['schema'], ) - self.body_name = 'body' - i = 2 - while self.body_name in self.parameters: - self.body_name = f'body{i}' - i += 1 - self.parameters_schema = dict( type='object', properties={ **({ - self.body_name: self.body['schema'] + self.body['name']: self.body['schema'] } if self.body else {}), **{ name: param['schema'] @@ -46,14 +47,14 @@ class OpenAPIMethod: } }, components=catalog.get('components'), - required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else []) + required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) ) def __call__(self, **kwargs): if self.body: - body = kwargs.pop(self.body_name, None) + body = kwargs.pop(self.body['name'], None) if self.body['required']: - assert body is not None, f'Missing required body parameter: {self.body_name}' + assert body is not None, f'Missing required body parameter: {self.body["name"]}' else: body = None diff --git a/examples/agent/tools/std_tools.py b/examples/agent/tools/std_tools.py index 9093e8dc2..4d1e132a1 100644 --- a/examples/agent/tools/std_tools.py +++ b/examples/agent/tools/std_tools.py @@ -15,6 +15,9 @@ class Duration(BaseModel): months: Optional[int] = None years: Optional[int] = None + def __str__(self) -> str: + return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds" + @property def get_total_seconds(self) -> int: return sum([ @@ -29,6 +32,10 @@ class Duration(BaseModel): class WaitForDuration(BaseModel): duration: Duration + def __call__(self): + sys.stderr.write(f"Waiting for {self.duration}...\n") + time.sleep(self.duration.get_total_seconds) + class WaitForDate(BaseModel): until: date @@ -43,7 +50,7 @@ class WaitForDate(BaseModel): days, seconds = time_diff.days, time_diff.seconds - sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n") + sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n") time.sleep(days * 86400 + seconds) sys.stderr.write(f"Reached the target date: {self.until}\n") @@ -67,8 +74,8 @@ class StandardTools: return _for() @staticmethod - def say_out_loud(something: str) -> str: + def say_out_loud(something: str) -> None: """ Just says something. Used to say each thought out loud """ - return subprocess.check_call(["say", something]) + subprocess.check_call(["say", something]) diff --git a/examples/agent/utils.py b/examples/agent/utils.py index 4eff7f6ad..b381e8ef6 100644 --- a/examples/agent/utils.py +++ b/examples/agent/utils.py @@ -9,8 +9,10 @@ def load_source_as_module(source): i += 1 spec = importlib.util.spec_from_file_location(module_name, source) + assert spec, f'Failed to load {source} as module' module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module + assert spec.loader, f'{source} spec has no loader' spec.loader.exec_module(module) return module @@ -29,7 +31,7 @@ def collect_functions(module): if k == k.capitalize(): continue v = getattr(module, k) - if not callable(v) or isinstance(v, Type): + if not callable(v) or isinstance(v, type): continue if not hasattr(v, '__annotations__'): continue diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 826cd3f72..7ce0e13c4 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -55,9 +55,9 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item class BuiltinRule: - def __init__(self, content: str, deps: list = None): + def __init__(self, content: str, deps: List[str]): self.content = content - self.deps = deps or [] + self.deps = deps _up_to_15_digits = _build_repetition('[0-9]', 0, 15) @@ -118,7 +118,7 @@ class SchemaConverter: def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal + lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal ) return f'"{escaped}"' @@ -157,13 +157,13 @@ class SchemaConverter: self._rules[key] = rule return key - def resolve_refs(self, schema: dict, url: str): + def resolve_refs(self, schema: Any, url: str): ''' Resolves all $ref fields in the given schema, fetching any remote schemas, replacing $ref with absolute reference URL and populating self._refs with the respective referenced (sub)schema dictionaries. ''' - def visit(n: dict): + def visit(n: Any): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): @@ -223,7 +223,7 @@ class SchemaConverter: assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' pattern = pattern[1:-1] - sub_rule_ids = {} + sub_rule_ids: Dict[str, str] = {} i = 0 length = len(pattern) diff --git a/examples/openai/api.py b/examples/openai/api.py index 2de0ea686..49f4a5f7b 100644 --- a/examples/openai/api.py +++ b/examples/openai/api.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Json, TypeAdapter class FunctionCall(BaseModel): @@ -16,7 +16,7 @@ class Message(BaseModel): name: Optional[str] = None tool_call_id: Optional[str] = None content: Optional[str] - tool_calls: Optional[list[ToolCall]] = None + tool_calls: Optional[List[ToolCall]] = None class ToolFunction(BaseModel): name: str @@ -29,7 +29,7 @@ class Tool(BaseModel): class ResponseFormat(BaseModel): type: Literal["json_object"] - schema: Optional[Dict] = None + schema: Optional[Json[Any]] = None # type: ignore class LlamaCppParams(BaseModel): n_predict: Optional[int] = None @@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel): class ChatCompletionRequest(LlamaCppParams): model: str - tools: Optional[list[Tool]] = None - messages: list[Message] = None + tools: Optional[List[Tool]] = None + messages: Optional[List[Message]] = None prompt: Optional[str] = None response_format: Optional[ResponseFormat] = None @@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams): class Choice(BaseModel): index: int message: Message - logprobs: Optional[Json] = None + logprobs: Optional[Json[Any]] = None finish_reason: Union[Literal["stop"], Literal["tool_calls"]] class Usage(BaseModel): @@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel): object: Literal["chat.completion"] created: int model: str - choices: list[Choice] + choices: List[Choice] usage: Usage system_fingerprint: str error: Optional[CompletionError] = None diff --git a/examples/openai/llama_cpp_server_api.py b/examples/openai/llama_cpp_server_api.py index db934919d..db1c86041 100644 --- a/examples/openai/llama_cpp_server_api.py +++ b/examples/openai/llama_cpp_server_api.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from pydantic import Json from examples.openai.api import LlamaCppParams @@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams): cache_prompt: Optional[bool] = None grammar: Optional[str] = None - json_schema: Optional[Json] = None + json_schema: Optional[Json[Any]] = None diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py index 10f68fdce..d64db02ff 100644 --- a/examples/openai/prompting.py +++ b/examples/openai/prompting.py @@ -6,12 +6,12 @@ from pathlib import Path import random import re import sys -from typing import Annotated, Optional -from pydantic import BaseModel, Field +from typing import Annotated, Any, Optional +from pydantic import BaseModel, Field, Json from examples.json_schema_to_grammar import SchemaConverter from examples.openai.api import Tool, Message, FunctionCall, ToolCall -from examples.openai.gguf_kvs import GGUFKeyValues, Keys +from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore from examples.openai.ts_converter import SchemaToTypeScriptConverter # _THOUGHT_KEY = "thought" @@ -65,7 +65,7 @@ class ChatTemplate(BaseModel): @property def potentially_supports_parallel_calls(self) -> bool: - return self.formats_tool_result and self.formats_tool_name + return bool(self.formats_tool_result and self.formats_tool_name) def __init__(self, template: str, eos_token: str, bos_token: str): super().__init__(template=template, eos_token=eos_token, bos_token=bos_token) @@ -161,7 +161,7 @@ class ChatTemplate(BaseModel): @staticmethod def from_huggingface(model_id: str): - from transformers import LlamaTokenizer + from transformers import LlamaTokenizer # type: ignore tokenizer = LlamaTokenizer.from_pretrained(model_id) return ChatTemplate( template = tokenizer.chat_template or tokenizer.default_chat_template, @@ -170,7 +170,7 @@ class ChatTemplate(BaseModel): def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False): result = self._template.render( - messages=messages, + messages=[messages.model_dump() for messages in messages], eos_token=self.eos_token, bos_token='' if omit_bos else self.bos_token, raise_exception=raise_exception, @@ -180,7 +180,7 @@ class ChatTemplate(BaseModel): class ChatHandlerArgs(BaseModel): chat_template: ChatTemplate - response_schema: Optional[dict] = None + response_schema: Optional[Json[Any]] = None tools: Optional[list[Tool]] = None class ChatHandler(ABC): @@ -199,9 +199,9 @@ class ChatHandler(ABC): assert system_prompt.role == "system" # TODO: add to last system message, or create a new one just before the last user message system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None) - if system_message is not None: + if system_message: (i, m) = system_message - return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:] + return messages[:i] + [Message(role="system", content=(system_prompt.content + '\n' if system_prompt.content else '') + (m.content or ''))] + messages[i+1:] else: return [system_prompt] + messages @@ -282,7 +282,7 @@ class ChatHandler(ABC): if self.args.chat_template.expects_strict_user_assistant_alternance: new_messages=[] current_role = 'user' - current_content = [] + current_content: list[str] = [] def flush(): nonlocal current_content @@ -311,24 +311,24 @@ class ChatHandler(ABC): messages = new_messages # JSON! - messages = [m.model_dump() for m in messages] + # messages = [m.model_dump() for m in messages] # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: if self.args.chat_template.expects_stringified_function_arguments: messages = [ - { - **m, + Message(**{ + **m.model_dump(), "tool_calls": [ - { - **tc, + ToolCall(**{ + **tc.model_dump(), "function": { - "name": tc["function"]["name"], - "arguments": json.dumps(tc["function"]["arguments"]), + "name": tc.function.name, + "arguments": tc.function.arguments, } - } - for tc in m["tool_calls"] - ] if m.get("tool_calls") else None - } + }) + for tc in m.tool_calls + ] if m.tool_calls else None + }) for m in messages ] @@ -364,7 +364,7 @@ class ToolCallTagsChatHandler(ChatHandler): converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) tool_rules = [] - for tool in self.args.tools: + for tool in self.args.tools or []: parameters_schema = tool.function.parameters parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name) @@ -416,7 +416,7 @@ class ToolCallTagsChatHandler(ChatHandler): if len(parts) == 1: return Message(role="assistant", content=s) else: - content = [] + content: list[str] = [] tool_calls = [] for i, part in enumerate(parts): if i % 2 == 0: @@ -431,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler): id=gen_callid(), function=FunctionCall(**fc))) - content = '\n'.join(content).strip() - return Message(role="assistant", content=content if content else None, tool_calls=tool_calls) + content_str = '\n'.join(content).strip() + return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls) class TemplatedToolsChatHandler(ToolCallTagsChatHandler): @@ -444,7 +444,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler): role="system", content=template.replace( '{tools}', - '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools), + '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in (self.args.tools or [])), ) ) @@ -456,11 +456,11 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler): path = str(Path(__file__).parent / "hermes_function_calling") if path not in sys.path: sys.path.insert(0, path) try: - from examples.openai.hermes_function_calling.prompter import PromptManager + from examples.openai.hermes_function_calling.prompter import PromptManager # type: ignore except ImportError: raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`") - prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools]) + prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools or []]) assert len(prompt) == 1 and prompt[0]["role"] == "system" self.output_format_prompt = Message(**prompt[0]) @@ -471,7 +471,7 @@ class FunctionaryToolsChatHandler(ChatHandler): self.output_format_prompt = Message( role="system", content= '// Supported function definitions that should be called when necessary.\n' + - _tools_typescript_signatures(args.tools) + _tools_typescript_signatures(args.tools or []) ) converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) @@ -481,7 +481,7 @@ class FunctionaryToolsChatHandler(ChatHandler): converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' + converter._format_literal('\n')) - for i, tool in enumerate(self.args.tools) + for i, tool in enumerate(self.args.tools or []) ] not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>")) @@ -583,7 +583,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler): response_schema = converter.resolve_refs(args.response_schema or {"type": "string"}, 'response') tool_parameter_schemas = { tool.function.name: converter.resolve_refs(tool.function.parameters, tool.function.name) - for tool in self.args.tools + for tool in self.args.tools or [] } # sys.stderr.write(f"# RESOLVED RESPONSE SCHEMA: {json.dumps(response_schema, indent=2)}\n") # sys.stderr.write(f"# RESOLVED TOOL PARAMETER SCHEMA: {json.dumps(tool_parameter_schemas, indent=2)}\n") @@ -614,7 +614,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler): content='\n'.join([ 'You are a function calling AI model.', 'Here are the tools available:', - _tools_schema_signatures(self.args.tools, indent=2), + _tools_schema_signatures(self.args.tools or [], indent=2), # _tools_typescript_signatures(self.args.tools), _please_respond_with_schema( _make_bespoke_schema( @@ -716,10 +716,10 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op elif tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO: return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls) else: - raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}") + raise ValueError(f"Unsupported tool call style: {tool_style}") # os.environ.get('NO_TS') -def _please_respond_with_schema(schema: dict) -> str: +def _please_respond_with_schema(schema: Json[Any]) -> str: sig = json.dumps(schema, indent=2) # _ts_converter = SchemaToTypeScriptConverter() # # _ts_converter.resolve_refs(schema, 'schema') diff --git a/examples/openai/server.py b/examples/openai/server.py index b03d7e098..672b6176d 100644 --- a/examples/openai/server.py +++ b/examples/openai/server.py @@ -3,7 +3,7 @@ from pathlib import Path import time from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest -from examples.openai.gguf_kvs import GGUFKeyValues, Keys +from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore from examples.openai.api import ChatCompletionResponse, Choice, ChatCompletionRequest, Usage from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler @@ -21,12 +21,12 @@ def generate_id(prefix): return f"{prefix}{random.randint(0, 1 << 32)}" def main( - model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", + model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", template_hf_model_id_fallback: Annotated[Optional[str], typer.Option(help="If the GGUF model does not contain a chat template, get it from this HuggingFace tokenizer")] = 'meta-llama/Llama-2-7b-chat-hf', # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None, host: str = "localhost", port: int = 8080, - parallel_calls: Optional[bool] = False, + parallel_calls: bool = False, style: Optional[ToolsPromptStyle] = None, auth: Optional[str] = None, verbose: bool = False, @@ -39,10 +39,11 @@ def main( if endpoint: sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n") + assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint" chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback) else: - metadata = GGUFKeyValues(model) + metadata = GGUFKeyValues(Path(model)) if not context_length: context_length = metadata[Keys.LLM.CONTEXT_LENGTH] @@ -51,6 +52,7 @@ def main( chat_template = ChatTemplate.from_gguf(metadata) else: sys.stderr.write(f"# WARNING: Model does not contain a chat template, fetching it from HuggingFace tokenizer of {template_hf_model_id_fallback}\n") + assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when the model does not contain a chat template" chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback) if verbose: @@ -93,9 +95,8 @@ def main( verbose=verbose, ) - messages = chat_request.messages - - prompt = chat_handler.render_prompt(messages) + prompt = chat_handler.render_prompt(chat_request.messages) if chat_request.messages else chat_request.prompt + assert prompt is not None, "One of prompt or messages field is required" if verbose: sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n') diff --git a/examples/openai/ts_converter.py b/examples/openai/ts_converter.py index 3c04bab7d..245e389c1 100644 --- a/examples/openai/ts_converter.py +++ b/examples/openai/ts_converter.py @@ -1,6 +1,8 @@ -from typing import Any, List, Set, Tuple, Union +from typing import Any, Dict, List, Set, Tuple, Union import json +from pydantic import Json + class SchemaToTypeScriptConverter: # TODO: comments for arguments! # // Get the price of a particular car model @@ -15,17 +17,18 @@ class SchemaToTypeScriptConverter: # location: string, # }) => any; - def __init__(self): - self._refs = {} - self._refs_being_resolved = set() + def __init__(self, allow_fetch: bool = True): + self._refs: Dict[str, Json[Any]] = {} + self._refs_being_resolved: Set[str] = set() + self._allow_fetch = allow_fetch - def resolve_refs(self, schema: dict, url: str): + def resolve_refs(self, schema: Json[Any], url: str): ''' Resolves all $ref fields in the given schema, fetching any remote schemas, replacing $ref with absolute reference URL and populating self._refs with the respective referenced (sub)schema dictionaries. ''' - def visit(n: dict): + def visit(n: Json[Any]): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): @@ -64,7 +67,7 @@ class SchemaToTypeScriptConverter: return n return visit(schema) - def _desc_comment(self, schema: dict): + def _desc_comment(self, schema: Json[Any]): desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None return f'// {desc}\n' if desc else '' @@ -78,11 +81,11 @@ class SchemaToTypeScriptConverter: f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}' for prop_name, prop_schema in properties ] + ( - [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"] + [f"{self._desc_comment(additional_properties) if isinstance(additional_properties, dict) else ''}[key: string]: {self.visit(additional_properties)}"] if additional_properties is not None else [] )) + "\n}" - def visit(self, schema: dict): + def visit(self, schema: Json[Any]): def print_constant(v): return json.dumps(v) @@ -90,7 +93,7 @@ class SchemaToTypeScriptConverter: schema_format = schema.get('format') if 'oneOf' in schema or 'anyOf' in schema: - return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf')) + return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf') or []) elif isinstance(schema_type, list): return '|'.join(self.visit({'type': t}) for t in schema_type)