agent/openai:nits
This commit is contained in:
parent
ce2fb0155f
commit
ea34bd3e5c
10 changed files with 72 additions and 145 deletions
|
@ -128,7 +128,7 @@ def main(
|
|||
max_iterations: Optional[int] = 10,
|
||||
std_tools: Optional[bool] = False,
|
||||
auth: Optional[str] = None,
|
||||
allow_parallel_calls: Optional[bool] = False,
|
||||
parallel_calls: Optional[bool] = True,
|
||||
verbose: bool = False,
|
||||
|
||||
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
|
||||
|
@ -174,7 +174,7 @@ def main(
|
|||
"python", "-m", "examples.openai.server",
|
||||
"--model", model,
|
||||
*(['--verbose'] if verbose else []),
|
||||
*(['--allow-parallel-calls'] if allow_parallel_calls else []),
|
||||
*(['--parallel-calls'] if parallel_calls else []),
|
||||
*(['--context-length={context_length}'] if context_length else []),
|
||||
*([])
|
||||
]
|
||||
|
|
|
@ -1,15 +1,11 @@
|
|||
import atexit
|
||||
from datetime import date
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
import subprocess
|
||||
import sys
|
||||
from time import sleep
|
||||
import time
|
||||
import typer
|
||||
from pydantic import BaseModel, Json, TypeAdapter
|
||||
from annotated_types import MinLen
|
||||
from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
|
||||
import json, requests
|
||||
from typing import Union, Optional
|
||||
|
||||
class Duration(BaseModel):
|
||||
seconds: Optional[int] = None
|
||||
|
|
|
@ -1,28 +1,12 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel, Json
|
||||
from pydantic import Json
|
||||
|
||||
class LlamaCppServerCompletionRequest(BaseModel):
|
||||
from examples.openai.api import LlamaCppParams
|
||||
|
||||
class LlamaCppServerCompletionRequest(LlamaCppParams):
|
||||
prompt: str
|
||||
stream: Optional[bool] = None
|
||||
cache_prompt: Optional[bool] = None
|
||||
n_predict: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
min_p: Optional[float] = None
|
||||
tfs_z: Optional[float] = None
|
||||
typical_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
dynatemp_range: Optional[float] = None
|
||||
dynatemp_exponent: Optional[float] = None
|
||||
repeat_last_n: Optional[int] = None
|
||||
repeat_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
mirostat: Optional[bool] = None
|
||||
mirostat_tau: Optional[float] = None
|
||||
mirostat_eta: Optional[float] = None
|
||||
penalize_nl: Optional[bool] = None
|
||||
n_keep: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
grammar: Optional[str] = None
|
||||
json_schema: Optional[Json] = None
|
|
@ -1,15 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
import jinja2
|
||||
import json
|
||||
from pathlib import Path
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
# from typeguard import typechecked
|
||||
|
||||
from examples.json_schema_to_grammar import SchemaConverter
|
||||
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
|
||||
|
@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
|
|||
eos_token = tokenizer.eos_token)
|
||||
|
||||
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
||||
# sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
|
||||
# sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
|
||||
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
||||
new_messages=[]
|
||||
i = 0
|
||||
|
@ -161,7 +157,6 @@ class ChatTemplate(BaseModel):
|
|||
i += 1
|
||||
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
|
||||
messages = new_messages
|
||||
# print(f'messages={messages}')
|
||||
|
||||
result = self._template.render(
|
||||
messages=messages,
|
||||
|
@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
|
|||
raise_exception=raise_exception,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
# sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
|
||||
return result
|
||||
|
||||
class ChatHandlerArgs(BaseModel):
|
||||
|
@ -206,12 +200,11 @@ class NoToolsChatHandler(ChatHandler):
|
|||
self.output_format_prompt = None
|
||||
self.grammar = None
|
||||
|
||||
# @typechecked
|
||||
def parse(self, s: str) -> Optional[Message]:
|
||||
return Message(role="assistant", content=s)
|
||||
|
||||
class ToolCallTagsChatHandler(ChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
|
||||
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
|
||||
super().__init__(args)
|
||||
|
||||
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||
|
@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
|
|||
converter._add_rule(
|
||||
'root',
|
||||
# tool_call_rule)
|
||||
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
|
||||
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
|
||||
else f'{content_rule}* {tool_call_rule}?')
|
||||
self.grammar = converter.format_grammar()
|
||||
|
||||
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
||||
# # OR a tool-call message respecting the schema of any of the tools
|
||||
# converter._add_rule(
|
||||
# "root",
|
||||
# converter._format_literal(prefix) + " (" +
|
||||
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
|
||||
# converter._format_literal("<tool_call>") + " (" +
|
||||
# ' | '.join(tool_rules) +
|
||||
# ") " + converter._format_literal("</tool_call>") +
|
||||
# ")") # + converter._format_literal(suffix))
|
||||
|
||||
# @typechecked
|
||||
def parse(self, s: str) -> Optional[Message]:
|
||||
s = self.args.chat_template.strip_suffix(s)
|
||||
|
||||
|
@ -298,17 +279,10 @@ class ToolCallTagsChatHandler(ChatHandler):
|
|||
content = '\n'.join(content).strip()
|
||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
||||
|
||||
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
||||
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
||||
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
|
||||
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
|
||||
# return None
|
||||
# else:
|
||||
# return Message(role="assistant", content=s)
|
||||
|
||||
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
|
||||
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
|
||||
def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
|
||||
super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
|
||||
assert '{tools}' in template, 'Template must contain "{tools}"'
|
||||
|
||||
self.output_format_prompt = Message(
|
||||
|
@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
|||
)
|
||||
|
||||
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
|
||||
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||
super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
|
||||
|
||||
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
||||
path = str(Path(__file__).parent / "hermes_function_calling")
|
||||
|
@ -331,12 +305,12 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
|||
except ImportError:
|
||||
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
|
||||
|
||||
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
|
||||
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
|
||||
assert len(prompt) == 1 and prompt[0]["role"] == "system"
|
||||
self.output_format_prompt = Message(**prompt[0])
|
||||
|
||||
class FunctionaryToolsChatHandler(ChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||
super().__init__(args)
|
||||
|
||||
# Only allowing a single tool call at a time for now.
|
||||
|
@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
|||
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
||||
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
||||
converter._format_literal('\n'))
|
||||
# converter.visit(
|
||||
# dict(
|
||||
# type="object",
|
||||
# properties=dict(
|
||||
# name=dict(const=tool.function.name),
|
||||
# arguments=tool.function.parameters,
|
||||
# ),
|
||||
# required=['name', 'arguments']
|
||||
# ),
|
||||
# f'{tool.function.name}-tool-call'
|
||||
# )
|
||||
for i, tool in enumerate(self.args.tools)
|
||||
]
|
||||
|
||||
|
@ -378,30 +341,15 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
|||
tool_call_without_start_rule = converter._add_rule(
|
||||
'tool_call_without_start',
|
||||
' | '.join(tool_rules))
|
||||
# + ' ' +
|
||||
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
|
||||
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
||||
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
|
||||
converter._add_rule(
|
||||
'root',
|
||||
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
||||
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
|
||||
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
|
||||
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
|
||||
|
||||
self.grammar = converter.format_grammar()
|
||||
# converter._add_rule(
|
||||
# "root",
|
||||
# converter._format_literal(prefix) + " (" +
|
||||
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
||||
# (' | '.join(
|
||||
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
||||
# converter.visit(tool.function.parameters, tool.function.name + '-args')
|
||||
# for tool in tools
|
||||
# )) +
|
||||
# ") " +
|
||||
# ")") # + converter._format_literal(suffix))
|
||||
|
||||
# @typechecked
|
||||
def parse(self, s: str) -> Optional[Message]:
|
||||
s = self.args.chat_template.strip_suffix(s)
|
||||
|
||||
|
@ -433,7 +381,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
|||
content = '\n'.join(text_content).strip()
|
||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
||||
|
||||
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
|
||||
def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
|
|||
# "const": "tool_calls"
|
||||
# },
|
||||
"tool_calls": {
|
||||
"prefixItems": tool_call_schema if allow_parallel_calls \
|
||||
"prefixItems": tool_call_schema if parallel_calls \
|
||||
else [tool_call_schema],
|
||||
}
|
||||
},
|
||||
|
@ -474,7 +422,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
|
|||
}
|
||||
|
||||
class BespokeToolsChatHandler(ChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||
super().__init__(args)
|
||||
|
||||
# args.response_schema = args.response_schema or {}
|
||||
|
@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
|
|||
for tool in self.args.tools
|
||||
]
|
||||
},
|
||||
allow_parallel_calls=allow_parallel_calls,
|
||||
parallel_calls=parallel_calls,
|
||||
),
|
||||
'',
|
||||
)
|
||||
|
@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
|
|||
},
|
||||
"required": ["name", "arguments"]
|
||||
},
|
||||
allow_parallel_calls=allow_parallel_calls,
|
||||
parallel_calls=parallel_calls,
|
||||
)
|
||||
),
|
||||
])
|
||||
)
|
||||
|
||||
# @typechecked
|
||||
def parse(self, s: str) -> Optional[Message]:
|
||||
s = self.args.chat_template.strip_suffix(s)
|
||||
try:
|
||||
|
@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
|
|||
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
|
||||
])
|
||||
|
||||
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
|
||||
def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
|
||||
if not args.tools:
|
||||
return NoToolsChatHandler(args)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
||||
return FunctionaryToolsChatHandler(args, allow_parallel_calls=False)
|
||||
return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
|
||||
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
||||
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
|
||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
|
||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
|
||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
|
||||
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
|
||||
return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
||||
return Hermes2ProToolsChatHandler(args)
|
||||
else:
|
||||
|
|
|
@ -31,7 +31,7 @@ def main(
|
|||
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 8080,
|
||||
allow_parallel_calls: Optional[bool] = False,
|
||||
parallel_calls: Optional[bool] = True,
|
||||
auth: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
context_length: Optional[int] = None,
|
||||
|
@ -92,7 +92,7 @@ def main(
|
|||
|
||||
chat_handler = get_chat_handler(
|
||||
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
|
||||
allow_parallel_calls=allow_parallel_calls
|
||||
parallel_calls=parallel_calls
|
||||
)
|
||||
|
||||
messages = chat_request.messages
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue