agent: mypy type fixes
    mypy examples/agent/__main__.py
    mypy examples/agent/fastify.py
    mypy examples/openai/__main__.py

parent ea0c31b10b
commit 89dcc062a4

11 changed files with 109 additions and 98 deletions

examples/agent/__main__.py

@@ -3,7 +3,7 @@ import sys
 from time import sleep
 import typer
 from pydantic import BaseModel, Json, TypeAdapter
-from typing import Annotated, Callable, List, Union, Optional, Type
+from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
 import json, requests
 
 from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint
@@ -13,7 +13,7 @@ from examples.agent.utils import collect_functions, load_module
 from examples.openai.prompting import ToolsPromptStyle
 from examples.openai.subprocesses import spawn_subprocess
 
-def _get_params_schema(fn: Callable, verbose):
+def _get_params_schema(fn: Callable[[Any], Any], verbose):
     if isinstance(fn, OpenAPIMethod):
         return fn.parameters_schema
 
@@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose):
 
 def completion_with_tool_usage(
     *,
-    response_model: Optional[Union[Json, Type]]=None,
+    response_model: Optional[Union[Json[Any], type]]=None,
     max_iterations: Optional[int]=None,
-    tools: List[Callable],
+    tools: List[Callable[..., Any]],
     endpoint: str,
     messages: List[Message],
     auth: Optional[str],
@@ -56,7 +56,7 @@ def completion_with_tool_usage(
                 type="function",
                 function=ToolFunction(
                     name=fn.__name__,
-                    description=fn.__doc__,
+                    description=fn.__doc__ or '',
                     parameters=_get_params_schema(fn, verbose=verbose)
                 )
             )
@@ -128,7 +128,7 @@ def completion_with_tool_usage(
 def main(
     goal: Annotated[str, typer.Option()],
     tools: Optional[List[str]] = None,
-    format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
+    format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
    max_iterations: Optional[int] = 10,
    std_tools: Optional[bool] = False,
    auth: Optional[str] = None,
@@ -136,7 +136,7 @@ def main(
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
 
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     endpoint: Optional[str] = None,
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
@@ -187,8 +187,8 @@ def main(
         sleep(5)
 
     tool_functions = []
-    types = {}
-    for f in tools:
+    types: Dict[str, type] = {}
+    for f in (tools or []):
         if f.startswith('http://') or f.startswith('https://'):
             tool_functions.extend(openapi_methods_from_endpoint(f))
         else:
@@ -203,7 +203,7 @@ def main(
     if std_tools:
         tool_functions.extend(collect_functions(StandardTools))
 
-    response_model = None #str
+    response_model: Union[type, Json[Any]] = None #str
     if format:
         if format in types:
             response_model = types[format]
@@ -246,10 +246,7 @@ def main(
        seed=seed,
        n_probs=n_probs,
        min_keep=min_keep,
-        messages=[{
-            "role": "user",
-            "content": goal,
-        }]
+        messages=[Message(role="user", content=goal)],
    )
    print(result if response_model else f'➡️ {result}')
    # exit(0)
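
Note: the `format` option fix above is the standard mypy pattern for typer options: a default of `None` requires `Optional[str]`, not `str`, since modern mypy disallows implicit Optional defaults. A minimal sketch of the rule (hypothetical function, not part of this commit):

    from typing import Optional

    def render(fmt: Optional[str] = None) -> str:
        # mypy rejects `fmt: str = None`; the Optional must be explicit.
        return fmt if fmt is not None else "text"

    assert render() == "text"
    assert render("json") == "json"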

examples/agent/fastify.py

@@ -17,7 +17,7 @@ def bind_functions(app, module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
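
Note: `typing.Type` is an annotation-only alias, so mypy rejects it as the second argument of `isinstance`; the runtime test for "is this a class?" uses the builtin `type`, as the hunk above does. A small sketch of the distinction:

    # isinstance(v, type) is the runtime check for "v is a class";
    # typing.Type belongs only inside annotations.
    class Greeter:
        pass

    for v in (Greeter, Greeter(), len):
        print(v, isinstance(v, type))  # True, False, False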

examples/agent/openapi_client.py

@@ -17,28 +17,29 @@ class OpenAPIMethod:
         request_body = post_descriptor.get('requestBody')
 
         self.parameters = {p['name']: p for p in parameters}
-        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})'
+        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'
 
         self.body = None
-        self.body_name = None
         if request_body:
-            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})'
+            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'
+
+            body_name = 'body'
+            i = 2
+            while body_name in self.parameters:
+                body_name = f'body{i}'
+                i += 1
+
             self.body = dict(
+                name=body_name,
                 required=request_body['required'],
                 schema=request_body['content']['application/json']['schema'],
             )
 
-            self.body_name = 'body'
-            i = 2
-            while self.body_name in self.parameters:
-                self.body_name = f'body{i}'
-                i += 1
-
         self.parameters_schema = dict(
             type='object',
             properties={
                 **({
-                    self.body_name: self.body['schema']
+                    self.body['name']: self.body['schema']
                 } if self.body else {}),
                 **{
                     name: param['schema']
@@ -46,14 +47,14 @@ class OpenAPIMethod:
                 }
             },
             components=catalog.get('components'),
-            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else [])
+            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
         )
 
     def __call__(self, **kwargs):
         if self.body:
-            body = kwargs.pop(self.body_name, None)
+            body = kwargs.pop(self.body['name'], None)
             if self.body['required']:
-                assert body is not None, f'Missing required body parameter: {self.body_name}'
+                assert body is not None, f'Missing required body parameter: {self.body["name"]}'
         else:
             body = None
 
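
Note: folding the body name into the `self.body` dict replaces two correlated attributes (`self.body` and `self.body_name`, each Optional) with a single Optional that mypy can narrow with one truthiness check. Roughly the shape of the fix, with hypothetical names:

    from typing import Any, Dict, Optional

    class Method:
        def __init__(self, has_body: bool) -> None:
            # One Optional to narrow, instead of two that mypy cannot correlate.
            self.body: Optional[Dict[str, Any]] = (
                dict(name='body', required=True, schema={}) if has_body else None
            )

        def __call__(self, **kwargs: Any) -> Any:
            if self.body:  # narrows Optional[Dict] -> Dict
                return kwargs.pop(self.body['name'], None)
            return None

    assert Method(True)(body=42) == 42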

examples/agent/tools/std_tools.py

@@ -15,6 +15,9 @@ class Duration(BaseModel):
     months: Optional[int] = None
     years: Optional[int] = None
 
+    def __str__(self) -> str:
+        return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
+
     @property
     def get_total_seconds(self) -> int:
         return sum([
@@ -29,6 +32,10 @@ class Duration(BaseModel):
 class WaitForDuration(BaseModel):
     duration: Duration
 
+    def __call__(self):
+        sys.stderr.write(f"Waiting for {self.duration}...\n")
+        time.sleep(self.duration.get_total_seconds)
+
 class WaitForDate(BaseModel):
     until: date
 
@@ -43,7 +50,7 @@ class WaitForDate(BaseModel):
 
         days, seconds = time_diff.days, time_diff.seconds
 
-        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
+        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
         time.sleep(days * 86400 + seconds)
         sys.stderr.write(f"Reached the target date: {self.until}\n")
 
@@ -67,8 +74,8 @@ class StandardTools:
         return _for()
 
     @staticmethod
-    def say_out_loud(something: str) -> str:
+    def say_out_loud(something: str) -> None:
         """
             Just says something. Used to say each thought out loud
         """
-        return subprocess.check_call(["say", something])
+        subprocess.check_call(["say", something])
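
Note: `subprocess.check_call` is typed as returning `int` (it returns 0 or raises), so returning its result from a function annotated `-> str` was a type error; the fix declares `-> None` and drops the `return`. A sketch of the same shape, with `echo` swapped in for the macOS-only `say`:

    import subprocess

    def say_out_loud(something: str) -> None:
        # check_call returns int, which cannot satisfy `-> str`;
        # with `-> None` we simply don't return its result.
        subprocess.check_call(["echo", something])

    say_out_loud("hello")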

examples/agent/utils.py

@@ -9,8 +9,10 @@ def load_source_as_module(source):
         i += 1
 
     spec = importlib.util.spec_from_file_location(module_name, source)
+    assert spec, f'Failed to load {source} as module'
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
+    assert spec.loader, f'{source} spec has no loader'
     spec.loader.exec_module(module)
     return module
 
@@ -29,7 +31,7 @@ def collect_functions(module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
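
Note: `importlib.util.spec_from_file_location` is typed `Optional[ModuleSpec]` and `spec.loader` is `Optional[Loader]`, so both added asserts are needed before use; each narrows an Optional for mypy and doubles as an error message. Assert-based narrowing in isolation:

    from typing import Optional

    def narrow(value: Optional[str]) -> str:
        assert value is not None, "expected a value"
        return value  # mypy now treats value as str, not Optional[str]

    assert narrow("ok") == "ok"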

examples/json_schema_to_grammar.py

@@ -55,9 +55,9 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item
 
 
 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: List[str]):
         self.content = content
-        self.deps = deps or []
+        self.deps = deps
 
 _up_to_15_digits = _build_repetition('[0-9]', 0, 15)
 
@@ -118,7 +118,7 @@ class SchemaConverter:
 
     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
         )
         return f'"{escaped}"'
 
@@ -157,13 +157,13 @@ class SchemaConverter:
         self._rules[key] = rule
         return key
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Any, url: str):
         '''
             Resolves all $ref fields in the given schema, fetching any remote schemas,
             replacing $ref with absolute reference URL and populating self._refs with the
             respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Any):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -223,7 +223,7 @@ class SchemaConverter:
 
         assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
         pattern = pattern[1:-1]
-        sub_rule_ids = {}
+        sub_rule_ids: Dict[str, str] = {}
 
         i = 0
         length = len(pattern)
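
Note: the `_format_literal` change swaps `dict.get` (typed `Optional[str]`) for indexing (typed `str`), which is what `re.sub`'s replacement callback must return; the regex only ever matches keys that are present, so indexing is safe. A reduced sketch with a hypothetical escape table:

    import re
    from typing import Dict

    ESCAPES: Dict[str, str] = {'\n': '\\n', '"': '\\"'}
    ESCAPE_RE = re.compile(r'[\n"]')

    def format_literal(literal: str) -> str:
        # dict[key] returns str, matching re.sub's Callable[[Match], str];
        # dict.get(key) would be Optional[str] and fail type checking.
        return ESCAPE_RE.sub(lambda m: ESCAPES[m.group(0)], literal)

    assert format_literal('a"b') == 'a\\"b'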

examples/openai/api.py

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
@@ -16,7 +16,7 @@ class Message(BaseModel):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
     content: Optional[str]
-    tool_calls: Optional[list[ToolCall]] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 class ToolFunction(BaseModel):
     name: str
@@ -29,7 +29,7 @@ class Tool(BaseModel):
 
 class ResponseFormat(BaseModel):
     type: Literal["json_object"]
-    schema: Optional[Dict] = None
+    schema: Optional[Json[Any]] = None # type: ignore
 
 class LlamaCppParams(BaseModel):
     n_predict: Optional[int] = None
@@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel):
 
 class ChatCompletionRequest(LlamaCppParams):
     model: str
-    tools: Optional[list[Tool]] = None
-    messages: list[Message] = None
+    tools: Optional[List[Tool]] = None
+    messages: Optional[List[Message]] = None
     prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
 
@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
 class Choice(BaseModel):
     index: int
     message: Message
-    logprobs: Optional[Json] = None
+    logprobs: Optional[Json[Any]] = None
     finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
 
 class Usage(BaseModel):
@@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"]
     created: int
     model: str
-    choices: list[Choice]
+    choices: List[Choice]
     usage: Usage
     system_fingerprint: str
     error: Optional[CompletionError] = None
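
Note: replacing `list[...]` with `typing.List[...]` keeps the pydantic models working on Python 3.8, where PEP 585 builtin generics are unavailable at runtime (presumably the motivation here), and `messages: list[Message] = None` additionally needed an explicit Optional. A minimal sketch of the corrected field style:

    from typing import List, Optional
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: Optional[str] = None

    class ChatCompletionRequest(BaseModel):
        # `messages: list[Message] = None` fails twice over: implicit
        # Optional, and builtin generics that predate Python 3.9.
        messages: Optional[List[Message]] = None

    req = ChatCompletionRequest(messages=[Message(role="user", content="hi")])
    assert req.messages is not None and req.messages[0].role == "user"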

examples/openai/llama_cpp_server_api.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 from pydantic import Json
 
 from examples.openai.api import LlamaCppParams
@@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams):
     cache_prompt: Optional[bool] = None
 
     grammar: Optional[str] = None
-    json_schema: Optional[Json] = None
+    json_schema: Optional[Json[Any]] = None
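
Note: `pydantic.Json` is generic, and mypy flags the bare, unparameterized form, hence `Json[Any]` throughout this commit. `Json[...]` fields also parse JSON strings on validation, as in this sketch:

    from typing import Any, Optional
    from pydantic import BaseModel, Json

    class CompletionRequest(BaseModel):
        json_schema: Optional[Json[Any]] = None  # parameterized, not bare Json

    req = CompletionRequest(json_schema='{"type": "object"}')
    assert req.json_schema == {"type": "object"}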

examples/openai/prompting.py

@@ -6,12 +6,12 @@ from pathlib import Path
 import random
 import re
 import sys
-from typing import Annotated, Optional
-from pydantic import BaseModel, Field
+from typing import Annotated, Any, Optional
+from pydantic import BaseModel, Field, Json
 
 from examples.json_schema_to_grammar import SchemaConverter
 from examples.openai.api import Tool, Message, FunctionCall, ToolCall
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
 # _THOUGHT_KEY = "thought"
@@ -65,7 +65,7 @@ class ChatTemplate(BaseModel):
 
     @property
     def potentially_supports_parallel_calls(self) -> bool:
-        return self.formats_tool_result and self.formats_tool_name
+        return bool(self.formats_tool_result and self.formats_tool_name)
 
     def __init__(self, template: str, eos_token: str, bos_token: str):
         super().__init__(template=template, eos_token=eos_token, bos_token=bos_token)
@@ -161,7 +161,7 @@ class ChatTemplate(BaseModel):
 
     @staticmethod
     def from_huggingface(model_id: str):
-        from transformers import LlamaTokenizer
+        from transformers import LlamaTokenizer # type: ignore
         tokenizer = LlamaTokenizer.from_pretrained(model_id)
         return ChatTemplate(
             template = tokenizer.chat_template or tokenizer.default_chat_template,
@@ -170,7 +170,7 @@ class ChatTemplate(BaseModel):
 
     def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
         result = self._template.render(
-            messages=messages,
+            messages=[messages.model_dump() for messages in messages],
            eos_token=self.eos_token,
            bos_token='' if omit_bos else self.bos_token,
            raise_exception=raise_exception,
@@ -180,7 +180,7 @@ class ChatTemplate(BaseModel):
 
 class ChatHandlerArgs(BaseModel):
     chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
+    response_schema: Optional[Json[Any]] = None
     tools: Optional[list[Tool]] = None
 
 class ChatHandler(ABC):
@@ -199,9 +199,9 @@ class ChatHandler(ABC):
         assert system_prompt.role == "system"
         # TODO: add to last system message, or create a new one just before the last user message
         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-        if system_message is not None:
+        if system_message:
             (i, m) = system_message
-            return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:]
+            return messages[:i] + [Message(role="system", content=(system_prompt.content + '\n' if system_prompt.content else '') + (m.content or ''))] + messages[i+1:]
         else:
             return [system_prompt] + messages
 
@@ -282,7 +282,7 @@ class ChatHandler(ABC):
         if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
-            current_content = []
+            current_content: list[str] = []
 
             def flush():
                 nonlocal current_content
@@ -311,24 +311,24 @@ class ChatHandler(ABC):
             messages = new_messages
 
         # JSON!
-        messages = [m.model_dump() for m in messages]
+        # messages = [m.model_dump() for m in messages]
 
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
-                {
-                    **m,
+                Message(**{
+                    **m.model_dump(),
                     "tool_calls": [
-                        {
-                            **tc,
+                        ToolCall(**{
+                            **tc.model_dump(),
                             "function": {
-                                "name": tc["function"]["name"],
-                                "arguments": json.dumps(tc["function"]["arguments"]),
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
                             }
-                        }
-                        for tc in m["tool_calls"]
-                    ] if m.get("tool_calls") else None
-                }
+                        })
+                        for tc in m.tool_calls
+                    ] if m.tool_calls else None
+                })
                 for m in messages
             ]
 
@@ -364,7 +364,7 @@ class ToolCallTagsChatHandler(ChatHandler):
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = []
-        for tool in self.args.tools:
+        for tool in self.args.tools or []:
 
             parameters_schema = tool.function.parameters
             parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
@@ -416,7 +416,7 @@ class ToolCallTagsChatHandler(ChatHandler):
         if len(parts) == 1:
             return Message(role="assistant", content=s)
         else:
-            content = []
+            content: list[str] = []
             tool_calls = []
             for i, part in enumerate(parts):
                 if i % 2 == 0:
@@ -431,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler):
                             id=gen_callid(),
                             function=FunctionCall(**fc)))
 
-            content = '\n'.join(content).strip()
-            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+            content_str = '\n'.join(content).strip()
+            return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
 
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
@@ -444,7 +444,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
             role="system",
             content=template.replace(
                 '{tools}',
-                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in (self.args.tools or [])),
            )
        )
 
@@ -456,11 +456,11 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
         path = str(Path(__file__).parent / "hermes_function_calling")
         if path not in sys.path: sys.path.insert(0, path)
         try:
-            from examples.openai.hermes_function_calling.prompter import PromptManager
+            from examples.openai.hermes_function_calling.prompter import PromptManager # type: ignore
         except ImportError:
             raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
 
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools])
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools or []])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
         self.output_format_prompt = Message(**prompt[0])
 
@@ -471,7 +471,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         self.output_format_prompt = Message(
             role="system",
             content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(args.tools)
+                _tools_typescript_signatures(args.tools or [])
         )
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -481,7 +481,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
                converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                converter._format_literal('\n'))
-            for i, tool in enumerate(self.args.tools)
+            for i, tool in enumerate(self.args.tools or [])
        ]
 
        not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
@@ -583,7 +583,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
         response_schema = converter.resolve_refs(args.response_schema or {"type": "string"}, 'response')
         tool_parameter_schemas = {
             tool.function.name: converter.resolve_refs(tool.function.parameters, tool.function.name)
-            for tool in self.args.tools
+            for tool in self.args.tools or []
         }
         # sys.stderr.write(f"# RESOLVED RESPONSE SCHEMA: {json.dumps(response_schema, indent=2)}\n")
         # sys.stderr.write(f"# RESOLVED TOOL PARAMETER SCHEMA: {json.dumps(tool_parameter_schemas, indent=2)}\n")
@@ -614,7 +614,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
             content='\n'.join([
                 'You are a function calling AI model.',
                 'Here are the tools available:',
-                _tools_schema_signatures(self.args.tools, indent=2),
+                _tools_schema_signatures(self.args.tools or [], indent=2),
                 # _tools_typescript_signatures(self.args.tools),
                 _please_respond_with_schema(
                     _make_bespoke_schema(
@@ -716,10 +716,10 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
     elif tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls)
     else:
-        raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+        raise ValueError(f"Unsupported tool call style: {tool_style}")
 
 # os.environ.get('NO_TS')
-def _please_respond_with_schema(schema: dict) -> str:
+def _please_respond_with_schema(schema: Json[Any]) -> str:
     sig = json.dumps(schema, indent=2)
     # _ts_converter = SchemaToTypeScriptConverter()
     # # _ts_converter.resolve_refs(schema, 'schema')
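
Note: most of the churn in this file is one pattern: `self.args.tools` is `Optional[list[Tool]]`, and iterating a possibly-None value is a mypy error, so every loop and comprehension gains `or []`. In isolation:

    from typing import List, Optional

    def tool_names(tools: Optional[List[str]]) -> List[str]:
        # `for t in tools` is an error when tools can be None;
        # `tools or []` substitutes the empty list.
        return [t for t in tools or []]

    assert tool_names(None) == []
    assert tool_names(["fetch", "search"]) == ["fetch", "search"]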

examples/openai/__main__.py

@@ -3,7 +3,7 @@ from pathlib import Path
 import time
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.api import ChatCompletionResponse, Choice, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
 
@@ -21,12 +21,12 @@ def generate_id(prefix):
     return f"{prefix}{random.randint(0, 1 << 32)}"
 
 def main(
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     template_hf_model_id_fallback: Annotated[Optional[str], typer.Option(help="If the GGUF model does not contain a chat template, get it from this HuggingFace tokenizer")] = 'meta-llama/Llama-2-7b-chat-hf',
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = False,
+    parallel_calls: bool = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -39,10 +39,11 @@ def main(
 
     if endpoint:
         sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+        assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint"
         chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
     else:
-        metadata = GGUFKeyValues(model)
+        metadata = GGUFKeyValues(Path(model))
 
         if not context_length:
             context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
@@ -51,6 +52,7 @@ def main(
             chat_template = ChatTemplate.from_gguf(metadata)
         else:
             sys.stderr.write(f"# WARNING: Model does not contain a chat template, fetching it from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+            assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when the model does not contain a chat template"
             chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
     if verbose:
@@ -93,9 +95,8 @@ def main(
             verbose=verbose,
         )
 
-    messages = chat_request.messages
-
-    prompt = chat_handler.render_prompt(messages)
+    prompt = chat_handler.render_prompt(chat_request.messages) if chat_request.messages else chat_request.prompt
+    assert prompt is not None, "One of prompt or messages field is required"
 
     if verbose:
         sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
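
Note: the two added asserts narrow `template_hf_model_id_fallback` from `Optional[str]` to `str` before it reaches `ChatTemplate.from_huggingface`, and give a readable failure if the fallback was explicitly unset. The same pattern, reduced:

    from typing import Optional

    def resolve_template_id(fallback: Optional[str]) -> str:
        assert fallback, "a fallback HuggingFace model id is required"
        return fallback  # narrowed to str

    assert resolve_template_id("meta-llama/Llama-2-7b-chat-hf")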

examples/openai/ts_converter.py

@@ -1,6 +1,8 @@
-from typing import Any, List, Set, Tuple, Union
+from typing import Any, Dict, List, Set, Tuple, Union
 import json
 
+from pydantic import Json
+
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
     # // Get the price of a particular car model
@@ -15,17 +17,18 @@ class SchemaToTypeScriptConverter:
     #     location: string,
     # }) => any;
 
-    def __init__(self):
-        self._refs = {}
-        self._refs_being_resolved = set()
+    def __init__(self, allow_fetch: bool = True):
+        self._refs: Dict[str, Json[Any]] = {}
+        self._refs_being_resolved: Set[str] = set()
+        self._allow_fetch = allow_fetch
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Json[Any], url: str):
         '''
             Resolves all $ref fields in the given schema, fetching any remote schemas,
             replacing $ref with absolute reference URL and populating self._refs with the
             respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Json[Any]):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -64,7 +67,7 @@ class SchemaToTypeScriptConverter:
             return n
         return visit(schema)
 
-    def _desc_comment(self, schema: dict):
+    def _desc_comment(self, schema: Json[Any]):
         desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
         return f'// {desc}\n' if desc else ''
 
@@ -78,11 +81,11 @@ class SchemaToTypeScriptConverter:
            f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
            for prop_name, prop_schema in properties
        ] + (
-            [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
+            [f"{self._desc_comment(additional_properties) if isinstance(additional_properties, dict) else ''}[key: string]: {self.visit(additional_properties)}"]
            if additional_properties is not None else []
        )) + "\n}"
 
-    def visit(self, schema: dict):
+    def visit(self, schema: Json[Any]):
         def print_constant(v):
             return json.dumps(v)
 
@@ -90,7 +93,7 @@ class SchemaToTypeScriptConverter:
         schema_format = schema.get('format')
 
         if 'oneOf' in schema or 'anyOf' in schema:
-            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf') or [])
 
         elif isinstance(schema_type, list):
             return '|'.join(self.visit({'type': t}) for t in schema_type)
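
Note: annotating instance attributes at their first assignment is how mypy learns element types for empty containers; bare `{}` and `set()` would otherwise infer overly permissive types and later inserts would go unchecked. The corrected constructor shape:

    from typing import Any, Dict, Set

    class Converter:
        def __init__(self, allow_fetch: bool = True) -> None:
            # Explicit annotations give `{}` and `set()` element types
            # that mypy can enforce on later reads and writes.
            self._refs: Dict[str, Any] = {}
            self._refs_being_resolved: Set[str] = set()
            self._allow_fetch = allow_fetch

    c = Converter()
    c._refs["https://example.com/schema.json"] = {"type": "object"}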