agent: mypy type fixes

    mypy examples/agent/__main__.py
    mypy examples/agent/fastify.py
    mypy examples/openai/__main__.py

parent ea0c31b10b
commit 89dcc062a4

11 changed files with 109 additions and 98 deletions
examples/agent/__main__.py

@@ -3,7 +3,7 @@ import sys
 from time import sleep
 import typer
 from pydantic import BaseModel, Json, TypeAdapter
-from typing import Annotated, Callable, List, Union, Optional, Type
+from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
 import json, requests
 
 from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint
@@ -13,7 +13,7 @@ from examples.agent.utils import collect_functions, load_module
 from examples.openai.prompting import ToolsPromptStyle
 from examples.openai.subprocesses import spawn_subprocess
 
-def _get_params_schema(fn: Callable, verbose):
+def _get_params_schema(fn: Callable[[Any], Any], verbose):
     if isinstance(fn, OpenAPIMethod):
         return fn.parameters_schema
 
@@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose):
 
 def completion_with_tool_usage(
     *,
-    response_model: Optional[Union[Json, Type]]=None,
+    response_model: Optional[Union[Json[Any], type]]=None,
     max_iterations: Optional[int]=None,
-    tools: List[Callable],
+    tools: List[Callable[..., Any]],
     endpoint: str,
     messages: List[Message],
     auth: Optional[str],
@@ -56,7 +56,7 @@ def completion_with_tool_usage(
                 type="function",
                 function=ToolFunction(
                     name=fn.__name__,
-                    description=fn.__doc__,
+                    description=fn.__doc__ or '',
                     parameters=_get_params_schema(fn, verbose=verbose)
                 )
             )
@@ -128,7 +128,7 @@ def completion_with_tool_usage(
 def main(
     goal: Annotated[str, typer.Option()],
     tools: Optional[List[str]] = None,
-    format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
+    format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
@@ -136,7 +136,7 @@ def main(
     verbose: bool = False,
    style: Optional[ToolsPromptStyle] = None,
 
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     endpoint: Optional[str] = None,
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
@@ -187,8 +187,8 @@ def main(
         sleep(5)
 
     tool_functions = []
-    types = {}
-    for f in tools:
+    types: Dict[str, type] = {}
+    for f in (tools or []):
         if f.startswith('http://') or f.startswith('https://'):
             tool_functions.extend(openapi_methods_from_endpoint(f))
         else:
@@ -203,7 +203,7 @@ def main(
     if std_tools:
         tool_functions.extend(collect_functions(StandardTools))
 
-    response_model = None #str
+    response_model: Union[type, Json[Any]] = None #str
     if format:
         if format in types:
             response_model = types[format]
@@ -246,10 +246,7 @@ def main(
         seed=seed,
         n_probs=n_probs,
        min_keep=min_keep,
-        messages=[{
-            "role": "user",
-            "content": goal,
-        }]
+        messages=[Message(role="user", content=goal)],
     )
     print(result if response_model else f'➡️ {result}')
     # exit(0)
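A note on the two recurring fixes above: under mypy, an option annotated plain `str` may not default to `None` (implicit Optional is disallowed in strict configurations), and an empty `{}` literal gives the checker no element types to infer, so the variable needs an annotation up front. A minimal standalone sketch of both patterns (toy names, not the project's code):

    from typing import Dict, Optional

    def describe(fmt: Optional[str] = None) -> str:
        # mypy rejects `fmt: str = None`; Optional[str] makes the None default explicit.
        return fmt if fmt is not None else "default"

    # An empty dict literal is untyped for mypy; the annotation lets later
    # assignments like types["float"] = float type-check.
    types: Dict[str, type] = {}
    types["float"] = float
    print(describe(), sorted(types))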
examples/agent/fastify.py

@@ -17,7 +17,7 @@ def bind_functions(app, module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
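The `isinstance(v, Type)` to `isinstance(v, type)` change appears in a couple of files: `typing.Type` is an annotation-time alias (deprecated since Python 3.9 in favor of the builtin), and the builtin `type` is the idiomatic spelling for a runtime "is this a class?" check. A sketch of the filtering idiom, using the stdlib `math` module as stand-in input:

    import math

    def collect_plain_functions(module):
        fns = []
        for name in dir(module):
            value = getattr(module, name)
            # Classes are callable too; isinstance(value, type) filters them out.
            if callable(value) and not isinstance(value, type):
                fns.append(value)
        return fns

    print([f.__name__ for f in collect_plain_functions(math)][:3])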
examples/agent/openapi_client.py

@@ -17,28 +17,29 @@ class OpenAPIMethod:
         request_body = post_descriptor.get('requestBody')
 
         self.parameters = {p['name']: p for p in parameters}
-        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})'
+        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'
 
         self.body = None
-        self.body_name = None
         if request_body:
-            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})'
+            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'
 
+            body_name = 'body'
+            i = 2
+            while body_name in self.parameters:
+                body_name = f'body{i}'
+                i += 1
+
             self.body = dict(
+                name=body_name,
                 required=request_body['required'],
                 schema=request_body['content']['application/json']['schema'],
             )
 
-            self.body_name = 'body'
-            i = 2
-            while self.body_name in self.parameters:
-                self.body_name = f'body{i}'
-                i += 1
-
         self.parameters_schema = dict(
             type='object',
             properties={
                 **({
-                    self.body_name: self.body['schema']
+                    self.body['name']: self.body['schema']
                 } if self.body else {}),
                 **{
                     name: param['schema']
@@ -46,14 +47,14 @@ class OpenAPIMethod:
                 }
             },
             components=catalog.get('components'),
-            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else [])
+            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
         )
 
     def __call__(self, **kwargs):
         if self.body:
-            body = kwargs.pop(self.body_name, None)
+            body = kwargs.pop(self.body['name'], None)
             if self.body['required']:
-                assert body is not None, f'Missing required body parameter: {self.body_name}'
+                assert body is not None, f'Missing required body parameter: {self.body["name"]}'
         else:
             body = None
 
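Beyond the assertion messages now interpolating the in-scope `url` instead of `path`, the larger reshuffle here removes `self.body_name`: initializing it to `None` and reassigning a `str` later types the attribute `Optional[str]` at every use site, so the chosen name now lives as a required `'name'` key inside `self.body`. The collision-avoidance loop itself, in isolation (hypothetical function, same logic):

    def pick_body_key(parameters: dict) -> str:
        # Choose a key for the request body that doesn't shadow a query
        # parameter: 'body', then 'body2', 'body3', ... until one is free.
        body_name = 'body'
        i = 2
        while body_name in parameters:
            body_name = f'body{i}'
            i += 1
        return body_name

    print(pick_body_key({'q': 1, 'body': 2, 'body2': 3}))  # -> body3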
@@ -15,6 +15,9 @@ class Duration(BaseModel):
     months: Optional[int] = None
     years: Optional[int] = None
 
+    def __str__(self) -> str:
+        return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
+
     @property
     def get_total_seconds(self) -> int:
         return sum([
@@ -29,6 +32,10 @@ class Duration(BaseModel):
 class WaitForDuration(BaseModel):
     duration: Duration
 
+    def __call__(self):
+        sys.stderr.write(f"Waiting for {self.duration}...\n")
+        time.sleep(self.duration.get_total_seconds)
+
 class WaitForDate(BaseModel):
     until: date
 
@@ -43,7 +50,7 @@ class WaitForDate(BaseModel):
 
         days, seconds = time_diff.days, time_diff.seconds
 
-        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
+        sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
         time.sleep(days * 86400 + seconds)
         sys.stderr.write(f"Reached the target date: {self.until}\n")
 
@@ -67,8 +74,8 @@ class StandardTools:
         return _for()
 
     @staticmethod
-    def say_out_loud(something: str) -> str:
+    def say_out_loud(something: str) -> None:
         """
         Just says something. Used to say each thought out loud
         """
-        return subprocess.check_call(["say", something])
+        subprocess.check_call(["say", something])
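The `say_out_loud` fix is worth spelling out: `subprocess.check_call` returns the process's exit status as an `int` (and raises `CalledProcessError` on failure), so returning it from a function annotated `-> str` is a type error; the honest signature is `-> None` with no return. Sketch (using `echo` so it runs off macOS; the original shells out to `say`):

    import subprocess

    def say_out_loud(something: str) -> None:
        # check_call returns an int status (raising on failure); there is
        # nothing str-like to return, so the function is typed -> None.
        subprocess.check_call(["echo", something])

    say_out_loud("hello")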
examples/agent/utils.py

@@ -9,8 +9,10 @@ def load_source_as_module(source):
         i += 1
 
     spec = importlib.util.spec_from_file_location(module_name, source)
+    assert spec, f'Failed to load {source} as module'
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
+    assert spec.loader, f'{source} spec has no loader'
     spec.loader.exec_module(module)
     return module
 
@@ -29,7 +31,7 @@ def collect_functions(module):
         if k == k.capitalize():
             continue
         v = getattr(module, k)
-        if not callable(v) or isinstance(v, Type):
+        if not callable(v) or isinstance(v, type):
             continue
         if not hasattr(v, '__annotations__'):
             continue
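The two added asserts double as type narrowing: `importlib.util.spec_from_file_location` is typed as returning `Optional[ModuleSpec]` and `spec.loader` is Optional too, and mypy treats `assert x` as proof that `x` is non-None afterwards. The same pattern, runnable standalone (hypothetical call with a real file path):

    import importlib.util
    import sys

    def load_source_as_module(source: str, module_name: str):
        spec = importlib.util.spec_from_file_location(module_name, source)
        assert spec, f'Failed to load {source} as module'   # narrows Optional[ModuleSpec]
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        assert spec.loader, f'{source} spec has no loader'  # narrows the Optional loader
        spec.loader.exec_module(module)
        return module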
examples/json_schema_to_grammar.py

@@ -55,9 +55,9 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item
 
 
 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: List[str]):
         self.content = content
-        self.deps = deps or []
+        self.deps = deps
 
 _up_to_15_digits = _build_repetition('[0-9]', 0, 15)
 
@@ -118,7 +118,7 @@ class SchemaConverter:
 
     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
         )
         return f'"{escaped}"'
 
@@ -157,13 +157,13 @@ class SchemaConverter:
         self._rules[key] = rule
         return key
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Any, url: str):
         '''
         Resolves all $ref fields in the given schema, fetching any remote schemas,
         replacing $ref with absolute reference URL and populating self._refs with the
         respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Any):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -223,7 +223,7 @@ class SchemaConverter:
 
         assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
         pattern = pattern[1:-1]
-        sub_rule_ids = {}
+        sub_rule_ids: Dict[str, str] = {}
 
         i = 0
         length = len(pattern)
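`deps: list = None` leaned on mypy's legacy implicit-Optional behaviour, which strict configurations reject; making `deps` a required `List[str]` also lets `self.deps = deps` drop the `or []` fallback (call sites now pass `[]` explicitly). The `_format_literal` change is the mirror image: `dict.get` returns `Optional[str]`, which a `re.sub` callback may not return, while indexing with `[...]` yields a plain `str` and raises `KeyError` rather than silently substituting `None`. Before/after of the constructor, as a sketch:

    from typing import List, Optional

    class BuiltinRuleOld:
        def __init__(self, content: str, deps: Optional[List[str]] = None):
            # Explicit Optional satisfies mypy, but pushes `or []` into the body.
            self.deps = deps or []

    class BuiltinRule:
        def __init__(self, content: str, deps: List[str]):
            # Required parameter: callers spell out deps=[].
            self.deps = deps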
examples/openai/api.py

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
@@ -16,7 +16,7 @@ class Message(BaseModel):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
     content: Optional[str]
-    tool_calls: Optional[list[ToolCall]] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 class ToolFunction(BaseModel):
     name: str
@@ -29,7 +29,7 @@ class Tool(BaseModel):
 
 class ResponseFormat(BaseModel):
     type: Literal["json_object"]
-    schema: Optional[Dict] = None
+    schema: Optional[Json[Any]] = None # type: ignore
 
 class LlamaCppParams(BaseModel):
     n_predict: Optional[int] = None
@@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel):
 
 class ChatCompletionRequest(LlamaCppParams):
     model: str
-    tools: Optional[list[Tool]] = None
-    messages: list[Message] = None
+    tools: Optional[List[Tool]] = None
+    messages: Optional[List[Message]] = None
     prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
 
@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
 class Choice(BaseModel):
     index: int
     message: Message
-    logprobs: Optional[Json] = None
+    logprobs: Optional[Json[Any]] = None
     finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
 
 class Usage(BaseModel):
@@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"]
     created: int
     model: str
-    choices: list[Choice]
+    choices: List[Choice]
     usage: Usage
     system_fingerprint: str
     error: Optional[CompletionError] = None
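Three recurring fixes in this file: `list[...]` becomes `typing.List[...]` (subscripting the builtin `list` in an evaluated annotation needs Python 3.9+, so this presumably keeps older interpreters working); bare `Json` becomes `Json[Any]`, since pydantic's `Json` is generic and the unparameterized form is an implicit `Any` that mypy flags; and `messages: list[Message] = None` gains the `Optional[...]` that its `None` default always implied. The `# type: ignore` on the `schema` field is presumably because the name shadows pydantic's own `BaseModel.schema`. A toy model of the corrected shapes (assuming pydantic is installed):

    from typing import Any, List, Optional
    from pydantic import BaseModel, Json

    class Message(BaseModel):
        role: str
        content: Optional[str] = None

    class ChatCompletionRequest(BaseModel):
        model: str
        messages: Optional[List[Message]] = None  # was: messages: list[Message] = None
        logprobs: Optional[Json[Any]] = None      # bare Json is an implicit Any

    req = ChatCompletionRequest(model="x", logprobs='{"a": 1}')
    print(req.logprobs)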
examples/openai/llama_cpp_server_api.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 from pydantic import Json
 
 from examples.openai.api import LlamaCppParams
@@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams):
     cache_prompt: Optional[bool] = None
 
     grammar: Optional[str] = None
-    json_schema: Optional[Json] = None
+    json_schema: Optional[Json[Any]] = None
examples/openai/prompting.py

@@ -6,12 +6,12 @@ from pathlib import Path
 import random
 import re
 import sys
-from typing import Annotated, Optional
-from pydantic import BaseModel, Field
+from typing import Annotated, Any, Optional
+from pydantic import BaseModel, Field, Json
 
 from examples.json_schema_to_grammar import SchemaConverter
 from examples.openai.api import Tool, Message, FunctionCall, ToolCall
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
 # _THOUGHT_KEY = "thought"
@@ -65,7 +65,7 @@ class ChatTemplate(BaseModel):
 
     @property
     def potentially_supports_parallel_calls(self) -> bool:
-        return self.formats_tool_result and self.formats_tool_name
+        return bool(self.formats_tool_result and self.formats_tool_name)
 
     def __init__(self, template: str, eos_token: str, bos_token: str):
         super().__init__(template=template, eos_token=eos_token, bos_token=bos_token)
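The `bool(...)` wrapper is needed because `a and b` evaluates to one of its operands rather than to a `bool`; unless both `formats_*` properties are strict `bool`s, the expression's static type is a union that can't be returned from a `-> bool` property. In miniature:

    import re
    from typing import Optional

    def _find(pattern: str, text: str) -> Optional[re.Match]:
        return re.search(pattern, text)

    def has_both(text: str) -> bool:
        # `x and y` has type Optional[re.Match] here, so coerce for -> bool.
        return bool(_find(r'tool_result', text) and _find(r'tool_name', text))

    print(has_both("tool_result ... tool_name"))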
@@ -161,7 +161,7 @@ class ChatTemplate(BaseModel):
 
     @staticmethod
     def from_huggingface(model_id: str):
-        from transformers import LlamaTokenizer
+        from transformers import LlamaTokenizer # type: ignore
         tokenizer = LlamaTokenizer.from_pretrained(model_id)
         return ChatTemplate(
             template = tokenizer.chat_template or tokenizer.default_chat_template,
@@ -170,7 +170,7 @@ class ChatTemplate(BaseModel):
 
     def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
         result = self._template.render(
-            messages=messages,
+            messages=[messages.model_dump() for messages in messages],
             eos_token=self.eos_token,
             bos_token='' if omit_bos else self.bos_token,
             raise_exception=raise_exception,
@@ -180,7 +180,7 @@ class ChatTemplate(BaseModel):
 
 class ChatHandlerArgs(BaseModel):
     chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
+    response_schema: Optional[Json[Any]] = None
     tools: Optional[list[Tool]] = None
 
 class ChatHandler(ABC):
@@ -199,9 +199,9 @@ class ChatHandler(ABC):
         assert system_prompt.role == "system"
         # TODO: add to last system message, or create a new one just before the last user message
         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-        if system_message is not None:
+        if system_message:
             (i, m) = system_message
-            return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:]
+            return messages[:i] + [Message(role="system", content=(system_prompt.content + '\n' if system_prompt.content else '') + (m.content or ''))] + messages[i+1:]
         else:
             return [system_prompt] + messages
 
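In the system-prompt merge, `Message.content` is `Optional[str]`, so the old `system_prompt.content + '\n' + m.content` was both a mypy error and a latent runtime `TypeError` when either side is None; the rewrite guards each operand. (Switching to `if system_message:` is safe here because a found result is a 2-tuple, which is always truthy.) The guarded concatenation as a tiny helper, with a hypothetical name:

    from typing import Optional

    def merge_system_content(prefix: Optional[str], existing: Optional[str]) -> str:
        # Only prepend prefix + newline when present; treat a missing body as ''.
        return (prefix + '\n' if prefix else '') + (existing or '')

    print(merge_system_content("You are terse.", None))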
@@ -282,7 +282,7 @@ class ChatHandler(ABC):
         if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
-            current_content = []
+            current_content: list[str] = []
 
             def flush():
                 nonlocal current_content
@@ -311,24 +311,24 @@ class ChatHandler(ABC):
             messages = new_messages
 
         # JSON!
-        messages = [m.model_dump() for m in messages]
+        # messages = [m.model_dump() for m in messages]
 
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
-                {
-                    **m,
+                Message(**{
+                    **m.model_dump(),
                     "tool_calls": [
-                        {
-                            **tc,
+                        ToolCall(**{
+                            **tc.model_dump(),
                             "function": {
-                                "name": tc["function"]["name"],
-                                "arguments": json.dumps(tc["function"]["arguments"]),
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
                             }
-                        }
-                        for tc in m["tool_calls"]
-                    ] if m.get("tool_calls") else None
-                }
+                        })
+                        for tc in m.tool_calls
+                    ] if m.tool_calls else None
+                })
                 for m in messages
             ]
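The stringified-arguments path used to flatten every message to a dict up front (`messages = [m.model_dump() ...]`), which retypes `messages` from `list[Message]` to `list[dict]` mid-function; mypy pins a variable to one type. Rebuilding real `Message`/`ToolCall` objects keeps the type stable and allows attribute access (`tc.function.name`) instead of key lookups. The round-trip trick with toy pydantic models (assuming pydantic v2's `model_dump`):

    from typing import List, Optional
    from pydantic import BaseModel

    class Call(BaseModel):
        name: str
        arguments: str

    class Msg(BaseModel):
        role: str
        tool_calls: Optional[List[Call]] = None

    def stringify_args(messages: List[Msg]) -> List[Msg]:
        # Dump to a dict to edit nested fields, then re-wrap in the model
        # so the list stays List[Msg] for the type checker.
        return [
            Msg(**{
                **m.model_dump(),
                "tool_calls": [
                    Call(**{**c.model_dump(), "arguments": c.arguments.strip()})
                    for c in m.tool_calls
                ] if m.tool_calls else None,
            })
            for m in messages
        ]

    print(stringify_args([Msg(role="assistant", tool_calls=[Call(name="f", arguments=" x ")])]))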
@@ -364,7 +364,7 @@ class ToolCallTagsChatHandler(ChatHandler):
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = []
-        for tool in self.args.tools:
+        for tool in self.args.tools or []:
 
             parameters_schema = tool.function.parameters
             parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
@@ -416,7 +416,7 @@ class ToolCallTagsChatHandler(ChatHandler):
         if len(parts) == 1:
             return Message(role="assistant", content=s)
         else:
-            content = []
+            content: list[str] = []
             tool_calls = []
             for i, part in enumerate(parts):
                 if i % 2 == 0:
@@ -431,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler):
                         id=gen_callid(),
                         function=FunctionCall(**fc)))
 
-            content = '\n'.join(content).strip()
-            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+            content_str = '\n'.join(content).strip()
+            return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
 
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
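Two idioms recur from here down. First, `self.args.tools or []` lets loops and comprehensions consume the `Optional[list]` field without a separate None branch. Second, the `content` to `content_str` rename exists because the variable started life as a `list[str]` and was then reassigned a `str`; mypy treats that as a type conflict rather than a rebind, so the joined text gets its own name. Both in miniature:

    from typing import List, Optional

    def render(tools: Optional[List[str]], parts: List[str]) -> Optional[str]:
        # Optional list: iterate `tools or []` rather than testing for None.
        for tool in tools or []:
            parts.append(f"<tool:{tool}>")
        # Fresh name, fresh type: parts stays List[str], the join is a str.
        content_str = "\n".join(parts).strip()
        return content_str if content_str else None

    print(render(["search"], ["hello"]))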
@@ -444,7 +444,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
             role="system",
             content=template.replace(
                 '{tools}',
-                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in (self.args.tools or [])),
             )
         )
 
@@ -456,11 +456,11 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
         path = str(Path(__file__).parent / "hermes_function_calling")
         if path not in sys.path: sys.path.insert(0, path)
         try:
-            from examples.openai.hermes_function_calling.prompter import PromptManager
+            from examples.openai.hermes_function_calling.prompter import PromptManager # type: ignore
         except ImportError:
             raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
 
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools])
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[tool.model_dump_json() for tool in args.tools or []])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
         self.output_format_prompt = Message(**prompt[0])
 
@@ -471,7 +471,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         self.output_format_prompt = Message(
             role="system",
             content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(args.tools)
+                _tools_typescript_signatures(args.tools or [])
         )
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -481,7 +481,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
                 converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                 converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                 converter._format_literal('\n'))
-            for i, tool in enumerate(self.args.tools)
+            for i, tool in enumerate(self.args.tools or [])
         ]
 
         not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
@@ -583,7 +583,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
         response_schema = converter.resolve_refs(args.response_schema or {"type": "string"}, 'response')
         tool_parameter_schemas = {
             tool.function.name: converter.resolve_refs(tool.function.parameters, tool.function.name)
-            for tool in self.args.tools
+            for tool in self.args.tools or []
         }
         # sys.stderr.write(f"# RESOLVED RESPONSE SCHEMA: {json.dumps(response_schema, indent=2)}\n")
         # sys.stderr.write(f"# RESOLVED TOOL PARAMETER SCHEMA: {json.dumps(tool_parameter_schemas, indent=2)}\n")
@@ -614,7 +614,7 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
             content='\n'.join([
                 'You are a function calling AI model.',
                 'Here are the tools available:',
-                _tools_schema_signatures(self.args.tools, indent=2),
+                _tools_schema_signatures(self.args.tools or [], indent=2),
                 # _tools_typescript_signatures(self.args.tools),
                 _please_respond_with_schema(
                     _make_bespoke_schema(
@@ -716,10 +716,10 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
     elif tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls)
     else:
-        raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+        raise ValueError(f"Unsupported tool call style: {tool_style}")
 
 # os.environ.get('NO_TS')
-def _please_respond_with_schema(schema: dict) -> str:
+def _please_respond_with_schema(schema: Json[Any]) -> str:
     sig = json.dumps(schema, indent=2)
     # _ts_converter = SchemaToTypeScriptConverter()
     # # _ts_converter.resolve_refs(schema, 'schema')
examples/openai/__main__.py

@@ -3,7 +3,7 @@ from pathlib import Path
 import time
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
-from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys # type: ignore
 from examples.openai.api import ChatCompletionResponse, Choice, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
 
@@ -21,12 +21,12 @@ def generate_id(prefix):
     return f"{prefix}{random.randint(0, 1 << 32)}"
 
 def main(
-    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
+    model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     template_hf_model_id_fallback: Annotated[Optional[str], typer.Option(help="If the GGUF model does not contain a chat template, get it from this HuggingFace tokenizer")] = 'meta-llama/Llama-2-7b-chat-hf',
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = False,
+    parallel_calls: bool = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -39,10 +39,11 @@ def main(
 
     if endpoint:
         sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+        assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint"
         chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
     else:
-        metadata = GGUFKeyValues(model)
+        metadata = GGUFKeyValues(Path(model))
 
         if not context_length:
             context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
@@ -51,6 +52,7 @@ def main(
             chat_template = ChatTemplate.from_gguf(metadata)
         else:
             sys.stderr.write(f"# WARNING: Model does not contain a chat template, fetching it from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+            assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when the model does not contain a chat template"
             chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
 
     if verbose:
@@ -93,9 +95,8 @@ def main(
             verbose=verbose,
         )
 
-        messages = chat_request.messages
-
-        prompt = chat_handler.render_prompt(messages)
+        prompt = chat_handler.render_prompt(chat_request.messages) if chat_request.messages else chat_request.prompt
+        assert prompt is not None, "One of prompt or messages field is required"
 
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
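Since the `model` option is now a plain `str`, the GGUF reader gets `Path(model)` at the call site. The more interesting change is prompt selection: deriving `prompt` in one expression and asserting on it both validates the request (a client must send `messages` or `prompt`, and `messages` is now `Optional` in the model above) and narrows `Optional[str]` to `str` for the rest of the handler. The shape of it, with a stand-in renderer:

    from typing import List, Optional

    def render_prompt(messages: List[str]) -> str:
        return "\n".join(messages)

    def choose_prompt(messages: Optional[List[str]], prompt: Optional[str]) -> str:
        chosen = render_prompt(messages) if messages else prompt
        assert chosen is not None, "One of prompt or messages field is required"
        return chosen  # narrowed from Optional[str] to str

    print(choose_prompt(None, "hi"))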
examples/openai/ts_converter.py

@@ -1,6 +1,8 @@
-from typing import Any, List, Set, Tuple, Union
+from typing import Any, Dict, List, Set, Tuple, Union
 import json
 
+from pydantic import Json
+
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
     # // Get the price of a particular car model
@@ -15,17 +17,18 @@ class SchemaToTypeScriptConverter:
     #     location: string,
     # }) => any;
 
-    def __init__(self):
-        self._refs = {}
-        self._refs_being_resolved = set()
+    def __init__(self, allow_fetch: bool = True):
+        self._refs: Dict[str, Json[Any]] = {}
+        self._refs_being_resolved: Set[str] = set()
+        self._allow_fetch = allow_fetch
 
-    def resolve_refs(self, schema: dict, url: str):
+    def resolve_refs(self, schema: Json[Any], url: str):
         '''
         Resolves all $ref fields in the given schema, fetching any remote schemas,
         replacing $ref with absolute reference URL and populating self._refs with the
         respective referenced (sub)schema dictionaries.
         '''
-        def visit(n: dict):
+        def visit(n: Json[Any]):
             if isinstance(n, list):
                 return [visit(x) for x in n]
             elif isinstance(n, dict):
@@ -64,7 +67,7 @@ class SchemaToTypeScriptConverter:
                 return n
         return visit(schema)
 
-    def _desc_comment(self, schema: dict):
+    def _desc_comment(self, schema: Json[Any]):
         desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
         return f'// {desc}\n' if desc else ''
 
@@ -78,11 +81,11 @@ class SchemaToTypeScriptConverter:
             f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
         ] + (
-            [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
+            [f"{self._desc_comment(additional_properties) if isinstance(additional_properties, dict) else ''}[key: string]: {self.visit(additional_properties)}"]
             if additional_properties is not None else []
         )) + "\n}"
 
-    def visit(self, schema: dict):
+    def visit(self, schema: Json[Any]):
         def print_constant(v):
             return json.dumps(v)
 
@@ -90,7 +93,7 @@ class SchemaToTypeScriptConverter:
         schema_format = schema.get('format')
 
         if 'oneOf' in schema or 'anyOf' in schema:
-            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf') or [])
 
         elif isinstance(schema_type, list):
             return '|'.join(self.visit({'type': t}) for t in schema_type)
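The trailing `or []` matters because mypy types `schema.get('oneOf') or schema.get('anyOf')` as Optional even though the enclosing `if` proves one of the keys exists; the checker does not connect an `in` test to a later `.get`. With `or []` the generator's iteration target is definitely iterable:

    from typing import Any, Dict

    def alternatives(schema: Dict[str, Any]) -> str:
        if 'oneOf' in schema or 'anyOf' in schema:
            # .get() is Optional for mypy despite the guard above;
            # `or []` makes the operand non-Optional.
            return '|'.join(str(s) for s in schema.get('oneOf') or schema.get('anyOf') or [])
        return 'any'

    print(alternatives({'anyOf': ['string', 'number']}))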