Merge aea81772d6 into 9e359a4f47
Commit: 8b8b491179
8 changed files with 486 additions and 18 deletions
.gitignore (vendored): 1 changed line

@@ -92,3 +92,4 @@ examples/jeopardy/results.txt
 poetry.lock
 poetry.toml
 nppBackup
+functionary-test
Makefile: 2 changed lines

@@ -719,7 +719,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/functionary.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
examples/server/CMakeLists.txt: 2 changed lines

@@ -1,7 +1,7 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp functionary.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
examples/server/functionary-test.cpp: 110 added lines (new file)
@@ -0,0 +1,110 @@
#include <iostream>
#include <string>
#include <vector>
#include <sstream>

#include "json.hpp"
#include "functionary.hpp"

using json = nlohmann::json;

/**
 * A simple test program that allows testing functionary.hpp without using the server.
 * TODO: how to add this test to CI?
 *
 * Compile command: clear && g++ functionary-test.cpp -o functionary-test && ./functionary-test
 */

std::string test_oai_input_json = R"(
{
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_n_day_weather_forecast",
                "description": "Get an N-day weather forecast",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        },
                        "num_days": {
                            "type": "integer",
                            "description": "The number of days to forecast"
                        }
                    },
                    "required": ["location", "format", "num_days"]
                }
            }
        }
    ],
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": "What is the weather like in Boston?"
        },
        {
            "role": "assistant",
            "content": null,
            "tool_calls": [
                {
                    "type": "function",
                    "id": "get_car_price",
                    "function": {
                        "arguments": "{\"car_name\": \"Song\"}",
                        "name": "get_car_price"
                    }
                },
                {
                    "type": "function",
                    "id": "get_car_price",
                    "function": {
                        "arguments": "{\"car_name\": \"Tang\"}",
                        "name": "get_car_price"
                    }
                }
            ]
        },
        {
            "role": "tool",
            "tool_call_id": "get_car_price",
            "name": "get_car_price",
            "content": "{\"price\": {\"price\": \"$25000\"}}"
        },
        {
            "role": "tool",
            "tool_call_id": "get_car_price",
            "name": "get_car_price",
            "content": "{\"price\": {\"price\": \"$20000\"}}"
        }
    ]
}
)";

std::string test_response = R"(get_car_price
<|content|>{"car_name": "Song"}
<|from|>assistant
<|recipient|>get_car_price
<|content|>{"car_name": "Tang"}<|stop|>)";

int main() {
    auto test_oai_input = json::parse(test_oai_input_json);
    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
    std::cout << "\n" << prompt << "\n";

    std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";

    return 0;
}
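For orientation (not part of the patch): test_response above is written in functionary's turn markup, where each turn is a "<|from|>" line, a "<|recipient|>" line, and a "<|content|>" body, with "<|stop|>" closing the final turn. The sketch below, which assumes nothing beyond the standard library, pulls the recipient out of one such turn; a recipient other than "all" (or the no-tool-call marker) names the function being called, which is what functionary.hpp below relies on.

    #include <iostream>
    #include <string>

    int main() {
        // One assistant turn in the functionary markup used by test_response above
        const std::string turn =
            "<|from|>assistant\n"
            "<|recipient|>get_car_price\n"
            "<|content|>{\"car_name\": \"Song\"}<|stop|>";

        // The recipient line names the function the assistant wants to call
        const std::string tag = "<|recipient|>";
        const size_t start = turn.find(tag) + tag.size();
        const size_t end   = turn.find('\n', start);
        std::cout << "tool call to: " << turn.substr(start, end - start) << "\n"; // prints "get_car_price"
        return 0;
    }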
examples/server/functionary.hpp: 317 added lines (new file)
@@ -0,0 +1,317 @@
#include <iostream>
#include <string>
#include <vector>
#include <sstream>

#include "json.hpp"

using json = nlohmann::json;

/**
 * Integration with the functionary model: https://github.com/MeetKai/functionary
 * Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
 *
 * A typical flow is:
 * - Step 1: the user sends a request to the model
 * - Step 2: the model sends back a response to the user
 * - Step 3: the model sends back another response, addressed to a function (optional)
 * - Step 4: the function sends its returned value to the model
 * - Step 5: finally, the model sends the final response back to the user
 */

#define FUNCTIONARY_FN_PROMPT "// Supported function definitions that should be called when necessary."
#define FUNCTIONARY_RECIP_ALL "all"
#define FUNCTIONARY_RECIP_NONE "no-tool-call"

namespace llama_functionary {

template <typename T>
static T json_value(const json &body, const std::string &key, const T &default_value)
{
    // Fallback null to default value
    return body.contains(key) && !body.at(key).is_null()
        ? body.value(key, default_value)
        : default_value;
}

inline std::string str_replace(const std::string & original, const std::string & search, const std::string & replacement) {
    size_t pos = original.find(search);
    if (pos != std::string::npos) {
        std::string result = original;
        result.replace(pos, search.length(), replacement);
        return result;
    }
    return original;
}

inline std::vector<std::string> str_split(std::string str, const std::string & delimiter) {
    size_t pos = 0;
    std::string token;
    std::vector<std::string> output;
    while ((pos = str.find(delimiter)) != std::string::npos) {
        token = str.substr(0, pos);
        output.push_back(token);
        str.erase(0, pos + delimiter.length());
    }
    output.push_back(str); // the rest
    return output;
}

typedef struct message {
    std::string from; // can be "system", "user", "assistant" or the name of a function
    std::string recipient = FUNCTIONARY_RECIP_ALL;
    std::string content;
    bool has_stop = false;
    message() {}
    message(json oai_json) {
        from = json_value(oai_json, "role", std::string(""));
        if (from == "tool") {
            // response from function
            from = json_value(oai_json, "tool_call_id", std::string(""));
        }
        content = json_value(oai_json, "content", std::string(""));
    }
    message(std::string & prompt) {
        std::istringstream iss(prompt);
        std::string line;
        std::stringstream ss;
        int i = 0; // line number
        while (std::getline(iss, line)) {
            if (i == 0) {
                from = str_replace(line, "<|from|>", "");
            } else if (i == 1) {
                recipient = str_replace(line, "<|recipient|>", "");
            } else if (i == 2) {
                ss << str_replace(line, "<|content|>", "");
            } else {
                ss << "\n" << line;
            }
            ++i;
        }
        has_stop = ss.str().find("<|stop|>") != std::string::npos;
        content = str_replace(ss.str(), "<|stop|>", "");
    }
    std::string to_prompt() {
        std::stringstream ss;
        ss << "<|from|>" << from << "\n";
        ss << "<|recipient|>" << recipient << "\n";
        ss << "<|content|>" << content;
        if (has_stop) {
            ss << "<|stop|>";
        }
        ss << "\n";
        return ss.str();
    }
} message;

typedef struct function_param {
    std::string name;
    // type can be "string", "boolean", "number" (typescript types)
    // we do not support array for now
    std::string type;
    std::string desc;
    std::vector<json> allowed_values; // dynamic types
    bool required;
    function_param(std::string param_name, json & oai_json) {
        name = param_name;
        type = json_value(oai_json, "type", std::string());
        desc = json_value(oai_json, "description", std::string());
        if (oai_json.count("enum")) {
            allowed_values = oai_json["enum"];
        }
    }
} function_param;

typedef struct function_def {
    std::string name;
    std::string desc;
    std::vector<function_param> params;
    // parameters.type must always be "object"
    function_def(json & oai_json) {
        std::string type = json_value(oai_json, "type", std::string());
        if (type != "function") {
            throw std::runtime_error("Only tool type \"function\" is supported");
        }
        // function
        json inner_json = json_value(oai_json, "function", json::object());
        name = json_value(inner_json, "name", std::string());
        desc = json_value(inner_json, "description", std::string());
        // function.parameters
        json parameters = json_value(inner_json, "parameters", json::object());
        std::string param_type = json_value(parameters, "type", std::string());
        if (param_type != "object") {
            throw std::runtime_error("Only parameters type \"object\" is supported");
        }
        // function.parameters.properties
        json properties = json_value(parameters, "properties", json::object());
        for (auto& it : properties.items()) {
            std::string curr_prop = it.key();
            json data = json_value(properties, curr_prop, json::object());
            function_param param(curr_prop, data);
            params.push_back(param);
        }
        // TODO: add required !!!!!!!!!!!!!!
    }
} function_def;

// convert OAI type to typescript
inline std::string oai_type_to_ts(std::string & type, std::vector<json> & allowed_values) {
    if (!allowed_values.empty()) {
        std::stringstream ss;
        for (size_t i = 0; i < allowed_values.size(); ++i) {
            ss << allowed_values[i];
            if (i < allowed_values.size() - 1) {
                ss << " | ";
            }
        }
        return ss.str();
    }
    // non-enum types
    if (type == "string" || type == "number" || type == "boolean") {
        return type; // natively supported
    } else if (type == "bool") {
        return "boolean";
    } else if (type == "integer" || type == "float" || type == "double") {
        return "number";
    } else {
        throw std::runtime_error("Unsupported type: " + type);
    }
}

inline std::string serialize_function(function_def & fn) {
    std::stringstream ss;
    if (fn.name.empty()) {
        throw std::runtime_error("Function name is empty");
    }
    if (!fn.desc.empty()) {
        // TODO: what if the desc has multiple lines?
        ss << "// " << fn.desc << "\n";
    }
    ss << "type " << fn.name << " = (_: {\n";
    for (auto & param : fn.params) {
        if (!param.desc.empty()) {
            ss << "// " << param.desc << "\n";
        }
        ss << param.name << ": " << oai_type_to_ts(param.type, param.allowed_values) << ",\n";
    }
    // only support "any" return type for now
    ss << "}) => any;\n\n";
    return ss.str();
}

///////////////////////////////////////////
// Main hooks, to be called in oai.hpp

inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
    std::stringstream ss;
    // convert function definitions
    std::vector<json> tools = json_value(body, "tools", json::array());
    if (!tools.empty()) {
        std::stringstream ss_fn;
        ss_fn << FUNCTIONARY_FN_PROMPT << "\n";
        ss_fn << "namespace functions {" << "\n\n";
        for (auto & tool : tools) {
            function_def fn(tool);
            ss_fn << serialize_function(fn);
        }
        ss_fn << "} // namespace functions";
        // construct the message
        message fn_def_msg;
        fn_def_msg.from = "system";
        fn_def_msg.recipient = FUNCTIONARY_RECIP_ALL;
        fn_def_msg.content = ss_fn.str();
        ss << fn_def_msg.to_prompt();
    }
    // convert history
    std::vector<json> messages = json_value(body, "messages", json::array());
    for (auto & msg_json : messages) {
        // TODO: how to detect where to put "<|stop|>"?
        if (msg_json.count("tool_calls")) {
            // an assistant request to call functions, now re-passed as history
            std::vector<json> tool_calls = msg_json["tool_calls"];
            for (size_t i = 0; i < tool_calls.size(); i++) {
                auto & tc = tool_calls[i];
                message msg;
                msg.from = tc["function"]["name"];
                msg.content = tc["function"]["arguments"];
                msg.has_stop = i == tool_calls.size() - 1; // last msg
                ss << msg.to_prompt();
            }
        } else {
            // all other types of message
            message msg(msg_json);
            msg.has_stop = msg.from == "assistant"; // add stop if this is a single text message from the assistant (does not contain tool_calls)
            ss << msg.to_prompt();
        }
    }
    // add trailing assistant prompt
    if (add_ass) {
        ss << "<|from|>assistant\n<|recipient|>";
        if (!allow_tool) {
            ss << FUNCTIONARY_RECIP_NONE;
        }
    }
    return ss.str();
}

inline json convert_response_to_oai_choices(const std::string & content) {
    std::string input_full = content;
    std::string text_response;
    json tool_calls = json::array();
    // parse all turns
    std::vector<std::string> turns = str_split(input_full, "<|from|>");
    if (!turns.empty()) {
        // the first turn may not have the assistant tag (because it was part of the prompt added by "add_ass"), so we put it back before parsing the message
        // the "<|from|>" itself will be added back later
        if (turns[0].find("<|recipient|>") == std::string::npos) {
            turns[0] = "assistant\n<|recipient|>" + turns[0];
        }
    }
    for (auto & turn : turns) {
        std::string turn_full = "<|from|>" + turn;
        message msg(turn_full);
        if (msg.from != "assistant") {
            continue; // this case should never happen
        }
        if (msg.recipient != FUNCTIONARY_RECIP_ALL && msg.recipient != FUNCTIONARY_RECIP_NONE) {
            // the assistant decided to call a tool (step 3)
            tool_calls.push_back(json{
                {"id", msg.recipient}, // TODO: maybe generate a random part?
                {"type", "function"},
                {"function", json{
                    {"name", msg.recipient},
                    {"arguments", msg.content},
                }},
            });
        } else {
            // the assistant just wants to say something (step 2)
            text_response = msg.content;
        }
    }
    // build final response
    json choices = json::array();
    // TODO: technically, functionary can respond with both text and tool_calls in one shot, but for some reason the original OpenAI implementation returns only one of them, not both.
    if (tool_calls.size() > 0) {
        choices.push_back(json{
            {"index", 0},
            {"finish_reason", "tool_calls"},
            {"message", json{
                {"role", "assistant"},
                {"content", nullptr},
                {"tool_calls", tool_calls},
            }},
        });
    } else {
        choices.push_back(json{
            {"index", 0},
            {"finish_reason", "stop"},
            {"message", json{
                {"role", "assistant"},
                {"content", text_response},
            }},
        });
    }
    return choices;
}

} // namespace llama_functionary
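As a minimal usage sketch of the conversion above (assumptions: functionary.hpp and json.hpp are on the include path, and the get_car_price schema is a hypothetical tool definition modeled on the test file), this builds one function_def from an OpenAI-style tool object and prints its TypeScript-style serialization:

    #include <iostream>
    #include "json.hpp"
    #include "functionary.hpp"

    int main() {
        // Hypothetical tool definition in the OpenAI "tools" schema
        nlohmann::json tool = nlohmann::json::parse(R"({
            "type": "function",
            "function": {
                "name": "get_car_price",
                "description": "Get the price of a car by name",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "car_name": { "type": "string", "description": "Name of the car" }
                    },
                    "required": ["car_name"]
                }
            }
        })");

        llama_functionary::function_def fn(tool); // throws if the schema is not a "function" tool
        // Prints something like:
        //   // Get the price of a car by name
        //   type get_car_price = (_: {
        //   // Name of the car
        //   car_name: string,
        //   }) => any;
        std::cout << llama_functionary::serialize_function(fn);
        return 0;
    }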
examples/server/oai.hpp

@@ -9,6 +9,7 @@
 
 #include "json.hpp"
 #include "utils.hpp"
+#include "functionary.hpp"
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 

@@ -17,7 +18,8 @@ using json = nlohmann::json;
 inline static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json &body, /* openai api json semantics */
-    const std::string &chat_template)
+    const std::string &chat_template,
+    bool enable_tool_calls)
 {
     json llama_params;
 

@@ -32,7 +34,9 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
+    llama_params["prompt"] = enable_tool_calls
+        ? llama_functionary::convert_oai_to_prompt(body, true)
+        : format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);

@@ -63,13 +67,19 @@ inline static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    llama_params["stop"].push_back(enable_tool_calls
+        ? "<|stop|>"   // functionary-specific: this model uses "<|stop|>" instead of "</s>"
+        : "<|im_end|>" // ensure there is a ChatML-specific end sequence among the stop words
+    );
 
     return llama_params;
 }
 
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+inline static json format_final_response_oaicompat(
+    const json &request,
+    const task_result &response,
+    bool streaming,
+    bool enable_tool_calls)
 {
     json result = response.result_json;
 

@@ -84,14 +94,20 @@ inline static json format_final_response_oaicompat(const json &request, const ta
         finish_reason = "stop";
     }
 
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
+    json choices;
+
+    if (enable_tool_calls) {
+        choices = llama_functionary::convert_response_to_oai_choices(content);
+    } else {
+        choices = streaming
+            ? json::array({json{{"finish_reason", finish_reason},
+                                {"index", 0},
+                                {"delta", json::object()}}})
+            : json::array({json{{"finish_reason", finish_reason},
+                                {"index", 0},
+                                {"message", json{{"content", content},
+                                                 {"role", "assistant"}}}}});
+    }
 
     std::time_t t = std::time(0);
 
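For reference, when the tool-call branch above is taken, the choices array that llama_functionary::convert_response_to_oai_choices hands back to format_final_response_oaicompat has roughly the shape below for a single tool call. This is a hand-written illustration built with nlohmann::json, mirroring the fields set in functionary.hpp; it is not captured server output:

    #include <iostream>
    #include "json.hpp"
    using json = nlohmann::json;

    int main() {
        json choices = json::array({json{
            {"index", 0},
            {"finish_reason", "tool_calls"},
            {"message", json{
                {"role", "assistant"},
                {"content", nullptr},
                {"tool_calls", json::array({json{
                    {"id", "get_car_price"}, // the recipient doubles as the id
                    {"type", "function"},
                    {"function", json{
                        {"name", "get_car_price"},
                        {"arguments", "{\"car_name\": \"Song\"}"}
                    }}
                }})}
            }}
        }});
        std::cout << choices.dump(2) << "\n";
        return 0;
    }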
examples/server/server.cpp

@@ -43,6 +43,7 @@ struct server_params
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
     bool slots_endpoint = true;
+    bool enable_tool_calls = false;
 };
 
 bool server_verbose = false;

@@ -2759,7 +2760,14 @@ int main(int argc, char **argv)
         LOG_INFO("model loaded", {});
     }
 
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
+    // Check tool_call ability
+    sparams.enable_tool_calls = check_model_support_tool_calls(llama.model);
+    if (sparams.enable_tool_calls) {
+        LOG_INFO("Current model supports functionary tool_calls", {});
+    }
+
+    // custom chat template is not supplied
+    if (sparams.chat_template.empty() && !sparams.enable_tool_calls) {
         // check if the template that comes with the model is supported by us
         llama.validate_model_chat_template(sparams);
     }

@@ -2935,7 +2943,9 @@ int main(int argc, char **argv)
             if (!validate_api_key(req, res)) {
                 return;
             }
-            json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
+            json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template, sparams.enable_tool_calls);
+
+            // TODO: "enable_tool_calls" cannot be used with "stream" mode
 
             const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);

@@ -2946,7 +2956,7 @@ int main(int argc, char **argv)
             task_result result = llama.queue_results.recv(task_id);
 
             if (!result.error && result.stop) {
-                json oaicompat_result = format_final_response_oaicompat(data, result);
+                json oaicompat_result = format_final_response_oaicompat(data, result, false, sparams.enable_tool_calls);
 
                 res.set_content(oaicompat_result.dump(-1, ' ', false,
                                 json::error_handler_t::replace),
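A hypothetical way to exercise the patched endpoint from C++ is sketched below, reusing the httplib.h that the server already bundles. The host, port, and request body are assumptions (the body mirrors the test input earlier in this patch), and streaming is left off, matching the TODO above:

    #include <iostream>
    #include "httplib.h"

    int main() {
        httplib::Client cli("localhost", 8080); // assumed server address
        const std::string body = R"({
            "messages": [{"role": "user", "content": "What is the weather like in Boston?"}],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "get_n_day_weather_forecast",
                    "description": "Get an N-day weather forecast",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "The city and state"},
                            "num_days": {"type": "integer", "description": "Days to forecast"}
                        },
                        "required": ["location", "num_days"]
                    }
                }
            }]
        })";

        // Non-streaming request; a functionary model should answer with a tool_calls choice
        auto res = cli.Post("/v1/chat/completions", body, "application/json");
        if (res) {
            std::cout << res->status << "\n" << res->body << "\n";
        }
        return 0;
    }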
examples/server/utils.hpp

@@ -211,6 +211,20 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
+// Detect if the model supports tool_calls
+inline bool check_model_support_tool_calls(const struct llama_model * model) {
+    std::vector<char> model_template(2048, 0);
+    std::string template_key = "tokenizer.chat_template";
+    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+    if (res < 0) {
+        return false; // no template in the model
+    } else {
+        model_template.resize(res);
+        std::string tmpl(model_template.data(), model_template.size());
+        return tmpl.find("<|recipient|>") != std::string::npos;
+    }
+}
+
 //
 // work queue utils
 //
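A minimal sketch of probing the new helper outside the server (assumptions: the GGUF path is a placeholder, check_model_support_tool_calls is visible via examples/server/utils.hpp, and backend initialization is omitted since its signature varies by llama.cpp revision):

    #include <cstdio>
    #include "llama.h"
    #include "utils.hpp" // for check_model_support_tool_calls

    int main() {
        // Load the model metadata; call llama_backend_init(...) first if your revision requires it
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("functionary-small-v2.2.q4_0.gguf", mparams);
        if (model == nullptr) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // True when the embedded chat template contains the functionary "<|recipient|>" tag
        const bool tool_calls = check_model_support_tool_calls(model);
        printf("functionary tool_calls supported: %s\n", tool_calls ? "yes" : "no");

        llama_free_model(model);
        return 0;
    }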