agent: mypy type fixes

mypy examples/agent/__main__.py
mypy examples/agent/fastify.py
mypy examples/openai/__main__.py
Olivier Chafik 2024-04-10 19:45:13 +01:00 committed by ochafik
parent ea0c31b10b
commit 89dcc062a4
11 changed files with 109 additions and 98 deletions

examples/agent/__main__.py

@@ -3,7 +3,7 @@ import sys
 from time import sleep
 import typer
 from pydantic import BaseModel, Json, TypeAdapter
-from typing import Annotated, Callable, List, Union, Optional, Type
+from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
 import json, requests
 
 from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint
@@ -13,7 +13,7 @@ from examples.agent.utils import collect_functions, load_module
 from examples.openai.prompting import ToolsPromptStyle
 from examples.openai.subprocesses import spawn_subprocess
 
 
-def _get_params_schema(fn: Callable, verbose):
+def _get_params_schema(fn: Callable[[Any], Any], verbose):
     if isinstance(fn, OpenAPIMethod):
         return fn.parameters_schema
@@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose):
 
 def completion_with_tool_usage(
     *,
-    response_model: Optional[Union[Json, Type]]=None,
+    response_model: Optional[Union[Json[Any], type]]=None,
     max_iterations: Optional[int]=None,
-    tools: List[Callable],
+    tools: List[Callable[..., Any]],
     endpoint: str,
     messages: List[Message],
     auth: Optional[str],
@@ -56,7 +56,7 @@ def completion_with_tool_usage(
                 type="function",
                 function=ToolFunction(
                     name=fn.__name__,
-                    description=fn.__doc__,
+                    description=fn.__doc__ or '',
                     parameters=_get_params_schema(fn, verbose=verbose)
                 )
             )
@@ -128,7 +128,7 @@ def completion_with_tool_usage(
 def main(
     goal: Annotated[str, typer.Option()],
     tools: Optional[List[str]] = None,
-    format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
+    format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
@@ -136,7 +136,7 @@ def main(
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     endpoint: Optional[str] = None,
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
@@ -187,8 +187,8 @@ def main(
         sleep(5)
 
     tool_functions = []
-    types = {}
-    for f in tools:
+    types: Dict[str, type] = {}
+    for f in (tools or []):
         if f.startswith('http://') or f.startswith('https://'):
             tool_functions.extend(openapi_methods_from_endpoint(f))
         else:
@@ -203,7 +203,7 @@ def main(
     if std_tools:
         tool_functions.extend(collect_functions(StandardTools))
 
-    response_model = None #str
+    response_model: Union[type, Json[Any]] = None #str
     if format:
         if format in types:
             response_model = types[format]
@@ -246,10 +246,7 @@ def main(
         seed=seed,
        n_probs=n_probs,
         min_keep=min_keep,
-        messages=[{
-            "role": "user",
-            "content": goal,
-        }]
+        messages=[Message(role="user", content=goal)],
     )
     print(result if response_model else f'➡️ {result}')
     # exit(0)
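
Two idioms recur across these fixes: annotating containers that start out empty (otherwise mypy reports "Need type annotation"), and guarding iteration over an Optional list with `or []`. A minimal sketch of both patterns (the function and names are illustrative, not from the commit):

    from typing import Dict, List, Optional

    def index_tools(tools: Optional[List[str]] = None) -> Dict[str, type]:
        # A bare `types = {}` makes mypy ask for an annotation; spell it out.
        types: Dict[str, type] = {}
        # Iterating `tools` directly is an error while it may still be None;
        # `tools or []` narrows the None case away, as main() now does.
        for name in (tools or []):
            types[name] = str  # illustrative payload
        return types

    print(index_tools(['float']))  # {'float': <class 'str'>}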

examples/agent/fastify.py

@@ -17,7 +17,7 @@ def bind_functions(app, module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
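
Here, and in `utils.py` below, `isinstance(v, Type)` becomes `isinstance(v, type)`: `typing.Type` is an annotation-time alias, while a runtime class check wants the builtin. The same filter in isolation:

    def looks_like_tool(v: object) -> bool:
        # Classes are themselves instances of the builtin `type`,
        # so this keeps plain functions and drops class definitions.
        return callable(v) and not isinstance(v, type)

    class SomeModel: pass

    print(looks_like_tool(SomeModel))  # False (a class)
    print(looks_like_tool(print))      # True  (a plain callable)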

examples/agent/openapi_client.py

@@ -17,28 +17,29 @@ class OpenAPIMethod:
         request_body = post_descriptor.get('requestBody')
 
         self.parameters = {p['name']: p for p in parameters}
-        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})'
+        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'
 
         self.body = None
-        self.body_name = None
         if request_body:
-            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})'
+            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'
+            body_name = 'body'
+            i = 2
+            while body_name in self.parameters:
+                body_name = f'body{i}'
+                i += 1
             self.body = dict(
+                name=body_name,
                 required=request_body['required'],
                 schema=request_body['content']['application/json']['schema'],
             )
-            self.body_name = 'body'
-            i = 2
-            while self.body_name in self.parameters:
-                self.body_name = f'body{i}'
-                i += 1
 
         self.parameters_schema = dict(
             type='object',
             properties={
                 **({
-                    self.body_name: self.body['schema']
+                    self.body['name']: self.body['schema']
                 } if self.body else {}),
                 **{
                     name: param['schema']
@@ -46,14 +47,14 @@ class OpenAPIMethod:
                 }
             },
             components=catalog.get('components'),
-            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else [])
+            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
         )
 
     def __call__(self, **kwargs):
         if self.body:
-            body = kwargs.pop(self.body_name, None)
+            body = kwargs.pop(self.body['name'], None)
             if self.body['required']:
-                assert body is not None, f'Missing required body parameter: {self.body_name}'
+                assert body is not None, f'Missing required body parameter: {self.body["name"]}'
         else:
             body = None
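
Folding the generated parameter name into the `self.body` dict, instead of keeping a separate `self.body_name` attribute, leaves a single Optional for the type checker to narrow: every use of the name now sits behind an `if self.body:` guard. Reduced to its shape (a sketch, not the full class):

    from typing import Any, Dict, Optional

    class Method:
        def __init__(self, request_body: Optional[Dict[str, Any]]):
            self.body: Optional[Dict[str, Any]] = None
            if request_body:
                self.body = dict(name='body', schema=request_body['schema'])

        def __call__(self, **kwargs: Any) -> None:
            if self.body:  # one guard narrows both the dict and its 'name'
                print('body parameter:', kwargs.pop(self.body['name'], None))

    Method({'schema': {'type': 'object'}})(body={'x': 1})  # body parameter: {'x': 1}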


@@ -15,6 +15,9 @@ class Duration(BaseModel):
     months: Optional[int] = None
     years: Optional[int] = None
 
+    def __str__(self) -> str:
+        return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
+
     @property
     def get_total_seconds(self) -> int:
         return sum([
@@ -29,6 +32,10 @@ class Duration(BaseModel):
 class WaitForDuration(BaseModel):
     duration: Duration
 
+    def __call__(self):
+        sys.stderr.write(f"Waiting for {self.duration}...\n")
+        time.sleep(self.duration.get_total_seconds)
+
 class WaitForDate(BaseModel):
     until: date
 
@@ -43,7 +50,7 @@ class WaitForDate(BaseModel):
         days, seconds = time_diff.days, time_diff.seconds
 
-        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
+        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
         time.sleep(days * 86400 + seconds)
         sys.stderr.write(f"Reached the target date: {self.until}\n")
@@ -67,8 +74,8 @@ class StandardTools:
         return _for()
 
     @staticmethod
-    def say_out_loud(something: str) -> str:
+    def say_out_loud(something: str) -> None:
         """
         Just says something. Used to say each thought out loud
         """
-        return subprocess.check_call(["say", something])
+        subprocess.check_call(["say", something])

examples/agent/utils.py

@@ -9,8 +9,10 @@ def load_source_as_module(source):
         i += 1
 
     spec = importlib.util.spec_from_file_location(module_name, source)
+    assert spec, f'Failed to load {source} as module'
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
+    assert spec.loader, f'{source} spec has no loader'
     spec.loader.exec_module(module)
     return module
@@ -29,7 +31,7 @@ def collect_functions(module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
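
The two new asserts exist for the type checker: `importlib.util.spec_from_file_location` is annotated as returning `Optional[ModuleSpec]`, and `ModuleSpec.loader` is itself optional, so mypy rejects `spec.loader.exec_module(...)` until both None cases are ruled out. The narrowing pattern on its own:

    import importlib.util
    import sys

    def load_source(module_name: str, source: str):
        spec = importlib.util.spec_from_file_location(module_name, source)
        assert spec, f'Failed to load {source} as module'   # Optional[ModuleSpec] -> ModuleSpec
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        assert spec.loader, f'{source} spec has no loader'  # loader is Optional too
        spec.loader.exec_module(module)
        return module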

examples/json_schema_to_grammar.py

@@ -55,9 +55,9 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item
 
 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: List[str]):
         self.content = content
-        self.deps = deps or []
+        self.deps = deps
 
 _up_to_15_digits = _build_repetition('[0-9]', 0, 15)
@@ -118,7 +118,7 @@ class SchemaConverter:
     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
         )
         return f'"{escaped}"'
@@ -157,13 +157,13 @@
         self._rules[key] = rule
         return key
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Any, url: str):
         '''
             Resolves all $ref fields in the given schema, fetching any remote schemas,
             replacing $ref with absolute reference URL and populating self._refs with the
             respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Any):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -223,7 +223,7 @@ class SchemaConverter:
         assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
         pattern = pattern[1:-1]
-        sub_rule_ids = {}
+        sub_rule_ids: Dict[str, str] = {}
 
         i = 0
         length = len(pattern)
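
`deps: list = None` relied on implicit Optional, which mypy rejects by default since 0.990 (`--no-implicit-optional`), and left the element type unspecified. The commit makes the argument required; had the default been kept, the explicit spelling would be (a sketch of the alternative, not what was committed):

    from typing import List, Optional

    class BuiltinRule:
        def __init__(self, content: str, deps: Optional[List[str]] = None):
            self.content = content
            self.deps = deps or []  # state the Optional instead of implying it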

examples/openai/api.py

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
@@ -16,7 +16,7 @@ class Message(BaseModel):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
     content: Optional[str]
-    tool_calls: Optional[list[ToolCall]] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 class ToolFunction(BaseModel):
     name: str
@@ -29,7 +29,7 @@ class Tool(BaseModel):
 
 class ResponseFormat(BaseModel):
     type: Literal["json_object"]
-    schema: Optional[Dict] = None
+    schema: Optional[Json[Any]] = None # type: ignore
 
 class LlamaCppParams(BaseModel):
     n_predict: Optional[int] = None
@@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel):
 
 class ChatCompletionRequest(LlamaCppParams):
     model: str
-    tools: Optional[list[Tool]] = None
-    messages: list[Message] = None
+    tools: Optional[List[Tool]] = None
+    messages: Optional[List[Message]] = None
     prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
 class Choice(BaseModel):
     index: int
     message: Message
-    logprobs: Optional[Json] = None
+    logprobs: Optional[Json[Any]] = None
     finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
 
 class Usage(BaseModel):
@@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"]
     created: int
     model: str
-    choices: list[Choice]
+    choices: List[Choice]
     usage: Usage
     system_fingerprint: str
     error: Optional[CompletionError] = None
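
Most of the churn in this file is `list[X]` → `List[X]`: PEP 585 built-in generics need Python 3.9+, and pydantic evaluates annotations at class-creation time, so `list[ToolCall]` raises a `TypeError` outright on 3.8. The `messages` fix also makes the None default honest. For example:

    from typing import List, Optional
    from pydantic import BaseModel

    class ToolCall(BaseModel):
        id: str

    class Message(BaseModel):
        role: str
        content: Optional[str]
        # List[...] works on 3.8; list[...] would need 3.9+ (PEP 585).
        # A None default must be declared Optional or mypy flags the field.
        tool_calls: Optional[List[ToolCall]] = None

    print(Message(role="user", content="hi").tool_calls)  # None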

examples/openai/llama_cpp_server_api.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 from pydantic import Json
 
 from examples.openai.api import LlamaCppParams
@@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams):
     cache_prompt: Optional[bool] = None
     grammar: Optional[str] = None
-    json_schema: Optional[Json] = None
+    json_schema: Optional[Json[Any]] = None

examples/openai/prompting.py

@@ -6,12 +6,12 @@ from pathlib import Path
 import random
 import re
 import sys
-from typing import Annotated, Optional
-from pydantic import BaseModel, Field
+from typing import Annotated, Any, Optional
+from pydantic import BaseModel, Field, Json
 
 from examples.json_schema_to_grammar import SchemaConverter
 from examples.openai.api import Tool, Message, FunctionCall, ToolCall
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
 # _THOUGHT_KEY = "thought"
@@ -65,7 +65,7 @@ class ChatTemplate(BaseModel):
     @property
     def potentially_supports_parallel_calls(self) -> bool:
-        return self.formats_tool_result and self.formats_tool_name
+        return bool(self.formats_tool_result and self.formats_tool_name)
 
     def __init__(self, template: str, eos_token: str, bos_token: str):
         super().__init__(template=template, eos_token=eos_token, bos_token=bos_token)
@@ -161,7 +161,7 @@ class ChatTemplate(BaseModel):
     @staticmethod
     def from_huggingface(model_id: str):
-        from transformers import LlamaTokenizer
+        from transformers import LlamaTokenizer # type: ignore
         tokenizer = LlamaTokenizer.from_pretrained(model_id)
         return ChatTemplate(
             template = tokenizer.chat_template or tokenizer.default_chat_template,
@@ -170,7 +170,7 @@ class ChatTemplate(BaseModel):
     def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
         result = self._template.render(
-            messages=messages,
+            messages=[messages.model_dump() for messages in messages],
             eos_token=self.eos_token,
             bos_token='' if omit_bos else self.bos_token,
             raise_exception=raise_exception,
@@ -180,7 +180,7 @@ class ChatTemplate(BaseModel):
 
 class ChatHandlerArgs(BaseModel):
     chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
+    response_schema: Optional[Json[Any]] = None
     tools: Optional[list[Tool]] = None
 
 class ChatHandler(ABC):
@@ -199,9 +199,9 @@ class ChatHandler(ABC):
         assert system_prompt.role == "system"
         # TODO: add to last system message, or create a new one just before the last user message
         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-        if system_message is not None:
+        if system_message:
             (i, m) = system_message
-            return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:]
+            return messages[:i] + [Message(role="system", content=(system_prompt.content + '\n' if system_prompt.content else '') + (m.content or ''))] + messages[i+1:]
         else:
             return [system_prompt] + messages
@@ -282,7 +282,7 @@ class ChatHandler(ABC):
         if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
-            current_content = []
+            current_content: list[str] = []
 
             def flush():
                 nonlocal current_content
@@ -311,24 +311,24 @@
             messages = new_messages
 
         # JSON!
-        messages = [m.model_dump() for m in messages]
+        # messages = [m.model_dump() for m in messages]
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
-                {
-                    **m,
+                Message(**{
+                    **m.model_dump(),
                     "tool_calls": [
-                        {
-                            **tc,
+                        ToolCall(**{
+                            **tc.model_dump(),
                             "function": {
-                                "name": tc["function"]["name"],
-                                "arguments": json.dumps(tc["function"]["arguments"]),
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
                             }
-                        }
-                        for tc in m["tool_calls"]
-                    ] if m.get("tool_calls") else None
-                }
+                        })
+                        for tc in m.tool_calls
+                    ] if m.tool_calls else None
+                })
                 for m in messages
             ]
@@ -364,7 +364,7 @@ class ToolCallTagsChatHandler(ChatHandler):
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
 
         tool_rules = []
-        for tool in self.args.tools:
+        for tool in self.args.tools or []:
 
             parameters_schema = tool.function.parameters
             parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
@@ -416,7 +416,7 @@
         if len(parts) == 1:
             return Message(role="assistant", content=s)
         else:
-            content = []
+            content: list[str] = []
             tool_calls = []
             for i, part in enumerate(parts):
                 if i % 2 == 0:
@@ -431,8 +431,8 @@
                             id=gen_callid(),
                             function=FunctionCall(**fc)))
 
-            content = '\n'.join(content).strip()
-            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+            content_str = '\n'.join(content).strip()
+            return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
 
@@ -444,7 +444,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
             role="system",
             content=template.replace(
                 '{tools}',
-                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in (self.args.tools or [])),
             )
         )
@@ -456,11 +456,11 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
         path = str(Path(__file__).parent / "hermes_function_calling")
         if path not in sys.path: sys.path.insert(0, path)
         try:
-            from examples.openai.hermes_function_calling.prompter import PromptManager
+            from examples.openai.hermes_function_calling.prompter import PromptManager # type: ignore
         except ImportError:
             raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
 
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools])
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools or []])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
         self.output_format_prompt = Message(**prompt[0])
@@ -471,7 +471,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         self.output_format_prompt = Message(
             role="system",
             content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(args.tools)
+                _tools_typescript_signatures(args.tools or [])
         )
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -481,7 +481,7 @@
                 converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                 converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                 converter._format_literal('\n'))
-            for i, tool in enumerate(self.args.tools)
+            for i, tool in enumerate(self.args.tools or [])
         ]
 
         not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
@@ -583,7 +583,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
         response_schema = converter.resolve_refs(args.response_schema or {"type": "string"}, 'response')
         tool_parameter_schemas = {
             tool.function.name: converter.resolve_refs(tool.function.parameters, tool.function.name)
-            for tool in self.args.tools
+            for tool in self.args.tools or []
         }
         # sys.stderr.write(f"# RESOLVED RESPONSE SCHEMA: {json.dumps(response_schema, indent=2)}\n")
         # sys.stderr.write(f"# RESOLVED TOOL PARAMETER SCHEMA: {json.dumps(tool_parameter_schemas, indent=2)}\n")
@@ -614,7 +614,7 @@
                 content='\n'.join([
                     'You are a function calling AI model.',
                     'Here are the tools available:',
-                    _tools_schema_signatures(self.args.tools, indent=2),
+                    _tools_schema_signatures(self.args.tools or [], indent=2),
                     # _tools_typescript_signatures(self.args.tools),
                     _please_respond_with_schema(
                         _make_bespoke_schema(
@@ -716,10 +716,10 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
     elif tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls)
     else:
-        raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+        raise ValueError(f"Unsupported tool call style: {tool_style}")
 
 # os.environ.get('NO_TS')
-def _please_respond_with_schema(schema: dict) -> str:
+def _please_respond_with_schema(schema: Json[Any]) -> str:
     sig = json.dumps(schema, indent=2)
     # _ts_converter = SchemaToTypeScriptConverter()
     # # _ts_converter.resolve_refs(schema, 'schema')
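
`prompting.py` repeats two idioms from the earlier files: `# type: ignore` on imports that ship no type information (`gguf_kvs`, `transformers`, the vendored Hermes prompter), and `or []` wherever the Optional `tools` list is iterated. The latter in isolation:

    from typing import List, Optional

    class Tool:  # stand-in for the pydantic Tool model
        pass

    def count_tools(tools: Optional[List[Tool]]) -> int:
        # mypy rejects `for t in tools` while tools may be None;
        # `tools or []` treats None and an empty list the same way.
        return sum(1 for _ in (tools or []))

    print(count_tools(None), count_tools([Tool()]))  # 0 1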

examples/openai/__main__.py

@@ -3,7 +3,7 @@ from pathlib import Path
 import time
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.api import ChatCompletionResponse, Choice, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
@@ -21,12 +21,12 @@ def generate_id(prefix):
     return f"{prefix}{random.randint(0, 1 << 32)}"
 
 def main(
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     template_hf_model_id_fallback: Annotated[Optional[str], typer.Option(help="If the GGUF model does not contain a chat template, get it from this HuggingFace tokenizer")] = 'meta-llama/Llama-2-7b-chat-hf',
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = False,
+    parallel_calls: bool = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -39,10 +39,11 @@ def main(
 
     if endpoint:
         sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+        assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint"
         chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
     else:
-        metadata = GGUFKeyValues(model)
+        metadata = GGUFKeyValues(Path(model))
 
         if not context_length:
             context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
@@ -51,6 +52,7 @@ def main(
             chat_template = ChatTemplate.from_gguf(metadata)
         else:
             sys.stderr.write(f"# WARNING: Model does not contain a chat template, fetching it from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+            assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when the model does not contain a chat template"
             chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
     if verbose:
@@ -93,9 +95,8 @@ def main(
             verbose=verbose,
         )
 
-        messages = chat_request.messages
-
-        prompt = chat_handler.render_prompt(messages)
+        prompt = chat_handler.render_prompt(chat_request.messages) if chat_request.messages else chat_request.prompt
+        assert prompt is not None, "One of prompt or messages field is required"
 
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
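
The server's `main` shows two complementary moves: accept `--model` as a plain `str` and convert with `Path(model)` at the point of use, and assert an Optional fallback before passing it where a `str` is required. Roughly (names mirror the diff; the body is trimmed):

    from pathlib import Path
    from typing import Optional

    def main(model: str = "models/7B/ggml-model-f16.gguf",
             template_hf_model_id_fallback: Optional[str] = None) -> None:
        if not Path(model).exists():  # convert at the use site, not in the signature
            assert template_hf_model_id_fallback, "a fallback model id is required"
            print("falling back to", template_hf_model_id_fallback)

    main(template_hf_model_id_fallback="meta-llama/Llama-2-7b-chat-hf")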

examples/openai/ts_converter.py

@@ -1,6 +1,8 @@
-from typing import Any, List, Set, Tuple, Union
+from typing import Any, Dict, List, Set, Tuple, Union
 import json
 
+from pydantic import Json
+
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
     #   // Get the price of a particular car model
@@ -15,17 +17,18 @@ class SchemaToTypeScriptConverter:
     #   location: string,
     # }) => any;
 
-    def __init__(self):
-        self._refs = {}
-        self._refs_being_resolved = set()
+    def __init__(self, allow_fetch: bool = True):
+        self._refs: Dict[str, Json[Any]] = {}
+        self._refs_being_resolved: Set[str] = set()
+        self._allow_fetch = allow_fetch
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Json[Any], url: str):
         '''
             Resolves all $ref fields in the given schema, fetching any remote schemas,
             replacing $ref with absolute reference URL and populating self._refs with the
             respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Json[Any]):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -64,7 +67,7 @@
             return n
 
         return visit(schema)
 
-    def _desc_comment(self, schema: dict):
+    def _desc_comment(self, schema: Json[Any]):
         desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
         return f'// {desc}\n' if desc else ''
@@ -78,11 +81,11 @@
             f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
         ] + (
-            [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
+            [f"{self._desc_comment(additional_properties) if isinstance(additional_properties, dict) else ''}[key: string]: {self.visit(additional_properties)}"]
             if additional_properties is not None else []
         )) + "\n}"
 
-    def visit(self, schema: dict):
+    def visit(self, schema: Json[Any]):
         def print_constant(v):
             return json.dumps(v)
@@ -90,7 +93,7 @@
         schema_format = schema.get('format')
 
         if 'oneOf' in schema or 'anyOf' in schema:
-            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf') or [])
         elif isinstance(schema_type, list):
             return '|'.join(self.visit({'type': t}) for t in schema_type)
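
Finally, `SchemaToTypeScriptConverter.__init__` gets the same empty-container annotations (`Dict[str, Json[Any]]`, `Set[str]`), plus an `allow_fetch` flag mirroring `SchemaConverter`'s constructor. The attribute-annotation pattern, reduced to its core (plain `Any` stands in for pydantic's `Json[Any]` here):

    from typing import Any, Dict, Set

    class Converter:
        def __init__(self, allow_fetch: bool = True):
            # Annotating attributes at their first assignment lets mypy
            # check every later read and write against a concrete type.
            self._refs: Dict[str, Any] = {}
            self._refs_being_resolved: Set[str] = set()
            self._allow_fetch = allow_fetch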