server.py: re-enable grammar, accommodate Mistral's escaped underscores
This commit is contained in:
parent
aa9605c751
commit
a4062935a5
3 changed files with 166 additions and 54 deletions
|
@ -128,10 +128,12 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
|
||||||
'',
|
'',
|
||||||
'''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
|
'''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
|
||||||
'',
|
'',
|
||||||
'''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
|
# '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
|
||||||
|
'''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
|
||||||
'''<tool_call>''',
|
'''<tool_call>''',
|
||||||
'''{"arguments": <args-dict>, "name": <function-name>}''',
|
'''{"arguments": <args-dict>, "name": <function-name>}''',
|
||||||
'''</tool_call>''',
|
'''</tool_call>''',
|
||||||
|
'''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it.''',
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -201,17 +203,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
||||||
if s.endswith(suffix):
|
if s.endswith(suffix):
|
||||||
return s[:-len(suffix)]
|
return s[:-len(suffix)]
|
||||||
else:
|
else:
|
||||||
print(f"Expected suffix ({suffix}) not found: {s}")
|
sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
|
||||||
return s
|
return s
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
if _outputs_tool_call_tags(chat_format.tool_style):
|
if _outputs_tool_call_tags(chat_format.tool_style):
|
||||||
|
|
||||||
|
escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
|
||||||
|
|
||||||
tool_rules = [
|
tool_rules = [
|
||||||
converter.visit(
|
converter.visit(
|
||||||
dict(
|
dict(
|
||||||
type="object",
|
type="object",
|
||||||
properties=dict(
|
properties=dict(
|
||||||
name=dict(const=tool.function.name),
|
name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
|
||||||
|
else dict(const=tool.function.name),
|
||||||
arguments=tool.function.parameters,
|
arguments=tool.function.parameters,
|
||||||
),
|
),
|
||||||
required=['name', 'arguments']
|
required=['name', 'arguments']
|
||||||
|
@ -221,22 +227,45 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
||||||
for tool in tools
|
for tool in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
# Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
def format_literal(s: str) -> str:
|
||||||
# OR a tool-call message respecting the schema of any of the tools
|
if escapes_underscores:
|
||||||
|
return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
|
||||||
|
else:
|
||||||
|
return converter._format_literal(s)
|
||||||
|
|
||||||
|
tool_call_rule = converter._add_rule(
|
||||||
|
'tool_call',
|
||||||
|
format_literal("<tool_call>") + " (" +
|
||||||
|
' | '.join(tool_rules) +
|
||||||
|
") " + format_literal("</tool_call>"))
|
||||||
|
|
||||||
|
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
|
||||||
|
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
|
||||||
|
content_rule = converter._add_rule('content', '[^<] | "<" [^t<]? | "<t" [^o<]?')
|
||||||
|
# content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
|
||||||
converter._add_rule(
|
converter._add_rule(
|
||||||
"root",
|
'root',
|
||||||
converter._format_literal(prefix) + " (" +
|
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?')
|
||||||
(response_rule or converter.not_literal("<tool_call>")) + " | " +
|
|
||||||
converter._format_literal("<tool_call>") + " (" +
|
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
||||||
' | '.join(tool_rules) +
|
# # OR a tool-call message respecting the schema of any of the tools
|
||||||
") " + converter._format_literal("</tool_call>") +
|
# converter._add_rule(
|
||||||
")") # + converter._format_literal(suffix))
|
# "root",
|
||||||
|
# converter._format_literal(prefix) + " (" +
|
||||||
|
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
|
||||||
|
# converter._format_literal("<tool_call>") + " (" +
|
||||||
|
# ' | '.join(tool_rules) +
|
||||||
|
# ") " + converter._format_literal("</tool_call>") +
|
||||||
|
# ")") # + converter._format_literal(suffix))
|
||||||
|
|
||||||
@typechecked
|
@typechecked
|
||||||
def parse(s: str) -> Optional[Message]:
|
def parse(s: str) -> Optional[Message]:
|
||||||
s = strip_suffix(s)
|
s = strip_suffix(s)
|
||||||
|
|
||||||
# ls = s.lstrip()
|
if r'<tool\_call>' in s:
|
||||||
|
# Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
|
||||||
|
s = s.replace(r'\_', '_')
|
||||||
|
|
||||||
parts = _tool_call_re.split(s)
|
parts = _tool_call_re.split(s)
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
return Message(role="assistant", content=s)
|
return Message(role="assistant", content=s)
|
||||||
|
@ -247,13 +276,17 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
||||||
if i % 2 == 0:
|
if i % 2 == 0:
|
||||||
content.append(part)
|
content.append(part)
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
fc = json.loads(part)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=gen_callid(),
|
id=gen_callid(),
|
||||||
function=FunctionCall(**json.loads(part))))
|
function=FunctionCall(**fc)))
|
||||||
|
|
||||||
content = ''.join(content).strip()
|
content = '(...)'.join(content).strip()
|
||||||
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
||||||
|
|
||||||
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
||||||
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
||||||
|
@ -268,17 +301,54 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
||||||
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
||||||
# Only allowing a single tool call at a time for now.
|
# Only allowing a single tool call at a time for now.
|
||||||
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
||||||
|
|
||||||
|
tool_rules = [
|
||||||
|
converter._add_rule(
|
||||||
|
tool.function.name + '-call',
|
||||||
|
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
||||||
|
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
||||||
|
converter._format_literal('\n'))
|
||||||
|
# converter.visit(
|
||||||
|
# dict(
|
||||||
|
# type="object",
|
||||||
|
# properties=dict(
|
||||||
|
# name=dict(const=tool.function.name),
|
||||||
|
# arguments=tool.function.parameters,
|
||||||
|
# ),
|
||||||
|
# required=['name', 'arguments']
|
||||||
|
# ),
|
||||||
|
# f'{tool.function.name}-tool-call'
|
||||||
|
# )
|
||||||
|
for i, tool in enumerate(tools)
|
||||||
|
]
|
||||||
|
|
||||||
|
not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
|
||||||
|
content_without_start_rule = converter._add_rule('content_without_start', converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
|
||||||
|
start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
|
||||||
|
content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
|
||||||
|
tool_call_without_start_rule = converter._add_rule(
|
||||||
|
'tool_call_without_start',
|
||||||
|
' | '.join(tool_rules))
|
||||||
|
# + ' ' +
|
||||||
|
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
|
||||||
|
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
||||||
|
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
|
||||||
converter._add_rule(
|
converter._add_rule(
|
||||||
"root",
|
'root',
|
||||||
converter._format_literal(prefix) + " (" +
|
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
||||||
(response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*')
|
||||||
(' | '.join(
|
|
||||||
converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
# converter._add_rule(
|
||||||
converter.visit(tool.function.parameters, tool.function.name + '-args')
|
# "root",
|
||||||
for tool in tools
|
# converter._format_literal(prefix) + " (" +
|
||||||
)) +
|
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
||||||
") " +
|
# (' | '.join(
|
||||||
")") # + converter._format_literal(suffix))
|
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
||||||
|
# converter.visit(tool.function.parameters, tool.function.name + '-args')
|
||||||
|
# for tool in tools
|
||||||
|
# )) +
|
||||||
|
# ") " +
|
||||||
|
# ")") # + converter._format_literal(suffix))
|
||||||
|
|
||||||
@typechecked
|
@typechecked
|
||||||
def parse(s: str) -> Optional[Message]:
|
def parse(s: str) -> Optional[Message]:
|
||||||
|
@ -297,17 +367,25 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
||||||
if recipient == 'all':
|
if recipient == 'all':
|
||||||
text_content.append(content)
|
text_content.append(content)
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
arguments = json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f'Failed to parse tool call content as JSON: {content}')
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=gen_callid(),
|
id=gen_callid(),
|
||||||
function=FunctionCall(name=recipient, arguments=json.loads(content))))
|
function=FunctionCall(name=recipient, arguments=arguments)))
|
||||||
|
|
||||||
|
|
||||||
assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
|
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
|
||||||
|
|
||||||
content = '\n'.join(text_content).strip()
|
content = '\n'.join(text_content).strip()
|
||||||
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
||||||
|
|
||||||
return (converter.format_grammar(), parse)
|
return (converter.format_grammar(), parse)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
|
||||||
|
|
||||||
elif response_schema:
|
elif response_schema:
|
||||||
converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
|
converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
|
||||||
|
|
|
@ -30,27 +30,28 @@ def main(
|
||||||
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
||||||
host: str = "localhost",
|
host: str = "localhost",
|
||||||
port: int = 8080,
|
port: int = 8080,
|
||||||
main_server_endpoint: Optional[str] = None,
|
cpp_server_endpoint: Optional[str] = None,
|
||||||
main_server_host: str = "localhost",
|
cpp_server_host: str = "localhost",
|
||||||
main_server_port: Optional[int] = 8081,
|
cpp_server_port: Optional[int] = 8081,
|
||||||
):
|
):
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
metadata = GGUFKeyValues(model)
|
metadata = GGUFKeyValues(model)
|
||||||
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
||||||
chat_format = ChatFormat.from_gguf(metadata)
|
chat_format = ChatFormat.from_gguf(metadata)
|
||||||
print(chat_format)
|
# print(chat_format)
|
||||||
|
|
||||||
if not main_server_endpoint:
|
if not cpp_server_endpoint:
|
||||||
|
sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
|
||||||
server_process = subprocess.Popen([
|
server_process = subprocess.Popen([
|
||||||
"./server", "-m", model,
|
"./server", "-m", model,
|
||||||
"--host", main_server_host, "--port", f'{main_server_port}',
|
"--host", cpp_server_host, "--port", f'{cpp_server_port}',
|
||||||
'-ctk', 'q4_0', '-ctv', 'f16',
|
'-ctk', 'q4_0', '-ctv', 'f16',
|
||||||
"-c", f"{2*8192}",
|
"-c", f"{2*8192}",
|
||||||
# "-c", f"{context_length}",
|
# "-c", f"{context_length}",
|
||||||
])
|
], stdout=sys.stderr)
|
||||||
atexit.register(server_process.kill)
|
atexit.register(server_process.kill)
|
||||||
main_server_endpoint = f"http://{main_server_host}:{main_server_port}"
|
cpp_server_endpoint = f"http://{cpp_server_host}:{cpp_server_port}"
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -74,21 +75,17 @@ def main(
|
||||||
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
|
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
|
||||||
|
|
||||||
# TODO: Test whether the template supports formatting tool_calls
|
# TODO: Test whether the template supports formatting tool_calls
|
||||||
|
sys.stderr.write(f'\n{grammar}\n\n')
|
||||||
|
|
||||||
prompt = chat_format.render(messages, add_generation_prompt=True)
|
prompt = chat_format.render(messages, add_generation_prompt=True)
|
||||||
print(json.dumps(dict(
|
|
||||||
stream=chat_request.stream,
|
|
||||||
prompt=prompt,
|
|
||||||
# grammar=grammar,
|
|
||||||
), indent=2))
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{main_server_endpoint}/completions",
|
f"{cpp_server_endpoint}/completions",
|
||||||
json=LlamaCppServerCompletionRequest(
|
json=LlamaCppServerCompletionRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=chat_request.stream,
|
stream=chat_request.stream,
|
||||||
n_predict=300,
|
n_predict=300,
|
||||||
# grammar=grammar,
|
grammar=grammar,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=None)
|
timeout=None)
|
||||||
|
@ -103,7 +100,7 @@ def main(
|
||||||
# print(json.dumps(result, indent=2))
|
# print(json.dumps(result, indent=2))
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
print(json.dumps(result, indent=2))
|
sys.stderr.write(json.dumps(result, indent=2) + "\n")
|
||||||
# print(json.dumps(result.get('content'), indent=2))
|
# print(json.dumps(result.get('content'), indent=2))
|
||||||
message = parser(result["content"])
|
message = parser(result["content"])
|
||||||
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
|
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
|
||||||
|
@ -118,7 +115,6 @@ def main(
|
||||||
choices=[Choice(
|
choices=[Choice(
|
||||||
index=0,
|
index=0,
|
||||||
message=message,
|
message=message,
|
||||||
|
|
||||||
finish_reason="stop" if message.tool_calls is None else "tool_calls",
|
finish_reason="stop" if message.tool_calls is None else "tool_calls",
|
||||||
)],
|
)],
|
||||||
usage=Usage(
|
usage=Usage(
|
||||||
|
|
|
@ -4,23 +4,60 @@ set -euo pipefail
|
||||||
SERVER_PID=""
|
SERVER_PID=""
|
||||||
function cleanup() {
|
function cleanup() {
|
||||||
if [ -n "$SERVER_PID" ]; then
|
if [ -n "$SERVER_PID" ]; then
|
||||||
echo "# Killing server"
|
echo "# Killing server" >&2
|
||||||
kill $SERVER_PID
|
kill $SERVER_PID
|
||||||
wait $SERVER_PID
|
wait $SERVER_PID
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
echo "# Starting the server"
|
echo "# Starting the server" >&2
|
||||||
|
|
||||||
python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
|
args=(
|
||||||
# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
|
# --cpp_server_endpoint "http://localhost:8081"
|
||||||
# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
|
|
||||||
|
# --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf
|
||||||
|
|
||||||
|
# --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf
|
||||||
|
# --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf
|
||||||
|
|
||||||
|
# --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf
|
||||||
|
--model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
|
||||||
|
)
|
||||||
|
python -m examples.openai "${args[@]}" &
|
||||||
SERVER_PID=$!
|
SERVER_PID=$!
|
||||||
|
|
||||||
sleep 5
|
sleep 5
|
||||||
|
|
||||||
echo "# Send a message to the chat API"
|
echo "# Send a message to the chat API" >&2
|
||||||
|
|
||||||
|
# curl http://localhost:8080/v1/chat/completions \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -H "Authorization: Bearer $OPENAI_API_KEY" \
|
||||||
|
# -d '{
|
||||||
|
# "model": "gpt-3.5-turbo",
|
||||||
|
# "tools": [{
|
||||||
|
# "type": "function",
|
||||||
|
# "function": {
|
||||||
|
# "name": "get_current_weather",
|
||||||
|
# "description": "Get the current weather",
|
||||||
|
# "parameters": {
|
||||||
|
# "type": "object",
|
||||||
|
# "properties": {
|
||||||
|
# "location": {
|
||||||
|
# "type": "string",
|
||||||
|
# "description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
# "required": ["location"]
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# }],
|
||||||
|
# "messages": [
|
||||||
|
# {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
|
||||||
|
# ]
|
||||||
|
# }' | \
|
||||||
|
# jq .
|
||||||
|
|
||||||
curl http://localhost:8080/v1/chat/completions \
|
curl http://localhost:8080/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -77,6 +114,7 @@ curl http://localhost:8080/v1/chat/completions \
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
|
{"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
|
||||||
]
|
]
|
||||||
}'
|
}' | \
|
||||||
|
jq .
|
||||||
|
|
||||||
# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
|
# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue