server.py: pass all request options, comments in ts sigs, render tool calls
commit 5f3de16116 (parent 63a384deaf)
4 changed files with 107 additions and 41 deletions
@@ -1,5 +1,5 @@
 from typing import Any, Dict, Literal, Optional, Union
-from pydantic import BaseModel, Json
+from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
     name: str
@@ -31,10 +31,33 @@ class ResponseFormat(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     tools: Optional[list[Tool]] = None
-    messages: list[Message]
+    messages: Optional[list[Message]] = None
+    prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
-    temperature: float = 1.0
     stream: bool = False
+    cache_prompt: Optional[bool] = None
+    n_predict: Optional[int] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    min_p: Optional[float] = None
+    tfs_z: Optional[float] = None
+    typical_p: Optional[float] = None
+    temperature: float = 1.0
+    dynatemp_range: Optional[float] = None
+    dynatemp_exponent: Optional[float] = None
+    repeat_last_n: Optional[int] = None
+    repeat_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    mirostat: Optional[bool] = None
+    mirostat_tau: Optional[float] = None
+    mirostat_eta: Optional[float] = None
+    penalize_nl: Optional[bool] = None
+    n_keep: Optional[int] = None
+    seed: Optional[int] = None
+    n_probs: Optional[int] = None
+    min_keep: Optional[int] = None
 
 class Choice(BaseModel):
     index: int
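Illustration (not part of the commit): with every llama.cpp sampling option now an Optional field, a client payload round-trips through the Pydantic model, and a proxy can forward only the knobs the client actually set. A minimal sketch, assuming pydantic v2 and a trimmed-down model:

# Sketch: the trimmed model stands in for ChatCompletionRequest above.
from typing import Optional
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):
    model: str
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    seed: Optional[int] = None
    stream: bool = False

req = ChatCompletionRequest(model="local", top_k=40, seed=1234)
# exclude_none drops every option the client never set, so the
# llama.cpp server only sees explicitly requested knobs.
print(req.model_dump(exclude_none=True))
# -> {'model': 'local', 'temperature': 1.0, 'top_k': 40, 'seed': 1234, 'stream': False}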
@@ -41,7 +41,7 @@ class ChatFormat:
         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
         if system_message is not None:
             (i, m) = system_message
-            return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
+            return messages[:i] + [Message(role="system", content=system_prompt.content + '\n' + m.content)] + messages[i+1:]
         else:
             return [system_prompt] + messages
 
@@ -63,8 +63,16 @@ class ChatFormat:
                     assert messages[i+1].role == 'user'
                     new_messages.append(Message(
                         role="user",
-                        content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'))
+                        content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'
+                    ))
                     i += 2
+                elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content:
+                    tc = '\n'.join(f'<tool_call>{json.dumps(tc.model_dump())}</tool_call>' for tc in messages[i].tool_calls)
+                    new_messages.append(Message(
+                        role="assistant",
+                        content=f'{messages[i].content}\n{tc}'
+                    ))
+                    i += 1
                 else:
                     new_messages.append(messages[i])
                     i += 1
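Illustration (not part of the commit): the new elif branch re-renders parsed tool calls back into <tool_call> tags, so an assistant turn that mixes text and tool calls survives a round trip through the chat template. A sketch with plain dicts standing in for the Pydantic Message/ToolCall models:

import json

message = {
    "role": "assistant",
    "content": "Let me check the weather.",
    "tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}],
}

# Same flattening as the diff above, one <tool_call> tag per call.
tc = '\n'.join(f'<tool_call>{json.dumps(c)}</tool_call>' for c in message["tool_calls"])
print(f'{message["content"]}\n{tc}')
# Let me check the weather.
# <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>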
@@ -72,13 +80,15 @@ class ChatFormat:
             messages = new_messages
             # print(f'messages={messages}')
 
-        return self.template.render(
+        result = self.template.render(
             messages=messages,
             eos_token=self.eos_token,
             bos_token='' if omit_bos else self.bos_token,
             raise_exception=raise_exception,
             add_generation_prompt=add_generation_prompt,
         )
+        sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
+        return result
 
 # While the API will be usable with a generic tools usage like OpenAI,
 # (see https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models),
@@ -120,38 +130,29 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
         return Message(
             role="system",
             content='\n'.join([
-                '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
+                # '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
                 '''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
                 '''<tools>''',
-                *(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
+                _tools_typescript_signatures(tools),
+                # _tools_schema_signatures(tools, indent=indent),
                 '''</tools>''',
                 '',
-                '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
-                '',
+                # '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
+                # '',
                 # '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
                 '''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
                 '''<tool_call>''',
-                '''{"arguments": <args-dict>, "name": <function-name>}''',
+                '''{"name": <function-name>, "arguments": <args-dict>}''',
                 '''</tool_call>''',
-                '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it.''',
+                # '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
             ])
         )
 
     elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        ts_converter = SchemaToTypeScriptConverter()
-
         return Message(
             role="system",
-            content='\n'.join([
-                '// Supported function definitions that should be called when necessary.'
-                'namespace functions {',
-                *[
-                    '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
-                    'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
-                    for tool in tools
-                ],
-                '} // namespace functions',
-            ])
+            content= '// Supported function definitions that should be called when necessary.\n' +
+                _tools_typescript_signatures(tools)
         )
 
     elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
@@ -170,6 +171,20 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
     else:
         raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
 
+def _tools_typescript_signatures(tools: list[Tool]) -> str:
+    ts_converter = SchemaToTypeScriptConverter()
+    return 'namespace functions {' + '\n'.join(
+        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
+        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
+        for tool in tools
+    ) + '} // namespace functions'
+
+def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
+    return '\n'.join(
+        json.dumps(tool.model_dump(), indent=indent)
+        for tool in tools
+    )
+
 @typechecked
 def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
     return style in (
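Illustration (not part of the commit): _tools_typescript_signatures emits a TypeScript "namespace functions" block that is cheaper in tokens than dumped JSON schemas. A self-contained sketch with a toy stand-in for SchemaToTypeScriptConverter.visit (the real converter handles nested schemas; this stub only covers flat objects):

def visit(schema: dict) -> str:
    # Stub: map flat JSON-schema property types to TypeScript types.
    types = {"string": "string", "number": "number", "integer": "number", "boolean": "boolean"}
    required = set(schema.get("required", []))
    fields = ', '.join(
        f'{name}{"" if name in required else "?"}: {types.get(s.get("type"), "any")}'
        for name, s in schema.get("properties", {}).items()
    )
    return '{' + fields + '}'

schema = {"type": "object",
          "properties": {"city": {"type": "string"}},
          "required": ["city"]}
print('namespace functions {\n'
      '// Get the current weather.\n'
      f'type get_weather = (_: {visit(schema)}) => any;\n'
      '} // namespace functions')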
@@ -199,6 +214,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
 
+    allow_parallel_calls = False
+
     def strip_suffix(s: str) -> str:
         if s.endswith(suffix):
             return s[:-len(suffix)]
@@ -235,17 +252,19 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         tool_call_rule = converter._add_rule(
             'tool_call',
-            format_literal("<tool_call>") + " (" +
+            format_literal("<tool_call>") + " space (" +
             ' | '.join(tool_rules) +
-            ") " + format_literal("</tool_call>"))
+            ") space " + format_literal("</tool_call>"))# + ' space')
 
         # Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
         # So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
-        content_rule = converter._add_rule('content', '[^<] | "<" [^t<]? | "<t" [^o<]?')
+        content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
         # content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
         converter._add_rule(
             'root',
-            f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?')
+            # tool_call_rule)
+            f'{content_rule}* ({tool_call_rule}+ {content_rule}*)? if allow_parallel_calls' if allow_parallel_calls \
+                else f'{content_rule}* {tool_call_rule}?')
 
         # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
         # # OR a tool-call message respecting the schema of any of the tools
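Illustration (not part of the commit): the root rule now depends on allow_parallel_calls; printing both variants makes the difference explicit (rule names are placeholders for the generated GBNF rules):

content_rule, tool_call_rule = 'content', 'tool-call'
for allow_parallel_calls in (False, True):
    root = (f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?'
            if allow_parallel_calls
            else f'{content_rule}* {tool_call_rule}?')
    print(f'allow_parallel_calls={allow_parallel_calls}: root ::= {root}')
# allow_parallel_calls=False: root ::= content* tool-call?
# allow_parallel_calls=True: root ::= content* (tool-call+ content*)?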
@@ -285,7 +304,7 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                         id=gen_callid(),
                         function=FunctionCall(**fc)))
 
-            content = '(...)'.join(content).strip()
+            content = '\n'.join(content).strip()
             return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
 
         # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
@@ -338,7 +357,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         converter._add_rule(
             'root',
             f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
-            f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*')
+            f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
+                else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
 
         # converter._add_rule(
         #     "root",
@@ -59,8 +59,9 @@ def main(
     async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
         headers = {
             "Content-Type": "application/json",
-            "Authorization": request.headers.get("Authorization"),
         }
+        if (auth := request.headers.get("Authorization")):
+            headers["Authorization"] = auth
 
         if chat_request.response_format is not None:
             assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
@@ -75,18 +76,31 @@ def main(
         (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
 
         # TODO: Test whether the template supports formatting tool_calls
-        sys.stderr.write(f'\n{grammar}\n\n')
 
         prompt = chat_format.render(messages, add_generation_prompt=True)
 
+        sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
+        sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
+
+        data = LlamaCppServerCompletionRequest(
+            **{
+                k: v
+                for k, v in chat_request.model_dump().items()
+                if k not in (
+                    "prompt",
+                    "tools",
+                    "messages",
+                    "response_format",
+                )
+            },
+            prompt=prompt,
+            grammar=grammar,
+        ).model_dump()
+        sys.stderr.write(json.dumps(data, indent=2) + "\n")
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{cpp_server_endpoint}/completions",
-                json=LlamaCppServerCompletionRequest(
-                    prompt=prompt,
-                    stream=chat_request.stream,
-                    n_predict=1000,
-                    grammar=grammar,
-                ).model_dump(),
+                json=data,
                 headers=headers,
                 timeout=None)
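Illustration (not part of the commit): the pass-through is just a dict comprehension over the validated request. Everything the proxy consumes itself is stripped; the rest is forwarded verbatim to the llama.cpp /completions endpoint. A standalone sketch with invented field values:

chat_request = {"model": "local", "messages": [], "temperature": 0.7,
                "top_k": 40, "response_format": None}

# Keys the proxy handles itself and must not forward downstream.
handled = ("prompt", "tools", "messages", "response_format")
data = {k: v for k, v in chat_request.items() if k not in handled}
data.update(prompt="<rendered prompt>", grammar="<generated grammar>")
print(data)
# {'model': 'local', 'temperature': 0.7, 'top_k': 40,
#  'prompt': '<rendered prompt>', 'grammar': '<generated grammar>'}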
@@ -96,11 +110,11 @@ def main(
             return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
         else:
             result = response.json()
+            sys.stderr.write("# RESULT:\n\n" + json.dumps(result, indent=2) + "\n\n")
             if 'content' not in result:
                 # print(json.dumps(result, indent=2))
                 return JSONResponse(result)
 
-            sys.stderr.write(json.dumps(result, indent=2) + "\n")
             # print(json.dumps(result.get('content'), indent=2))
             message = parser(result["content"])
             assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
@@ -14,12 +14,21 @@ class SchemaToTypeScriptConverter:
     # // where to get weather.
     # location: string,
     # }) => any;
+    def _desc_comment(self, schema: dict):
+        desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
+        return f'// {desc}\n' if desc else ''
+
     def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+        if additional_properties == True:
+            additional_properties = {}
+        elif additional_properties == False:
+            additional_properties = None
+
         return "{" + ', '.join([
-            f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
+            f'{self._desc_comment(prop_schema)}{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
         ] + (
-            [f"[key: string]: {self.visit(additional_properties)}"]
+            [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
            if additional_properties is not None else []
        )) + "}"
 
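Illustration (not part of the commit): _desc_comment turns a schema's description into a // comment above the emitted property, which is where the "comments in ts sigs" of the commit title come from. A minimal standalone replica:

def _desc_comment(schema: dict) -> str:
    # Multi-line descriptions get '// ' prefixed on every line.
    desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
    return f'// {desc}\n' if desc else ''

prop = {"type": "string", "description": "City and country,\ne.g. Paris, France"}
print(_desc_comment(prop) + 'location: string,')
# // City and country,
# // e.g. Paris, France
# location: string,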