first working version

This commit is contained in:
  parent 2a0d74d52e
  commit aea81772d6

5 changed files with 39 additions and 20 deletions

@@ -93,11 +93,7 @@ std::string test_oai_input_json = R"(
 )";
 
 
-std::string test_response = R"(<|from|>assistant
-<|recipient|>all
-<|content|>I will get the price of 2 cars and compare
-<|from|>assistant
-<|recipient|>get_car_price
+std::string test_response = R"(get_car_price
 <|content|>{"car_name": "Song"}
 <|from|>assistant
 <|recipient|>get_car_price

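Note on the test data change: with the new add_ass flag (see the convert_oai_to_prompt hunk below), the prompt itself now ends with the assistant header, so the canned completion in test_response no longer repeats "<|from|>assistant\n<|recipient|>". A minimal sketch of how the two halves are meant to line up; the string values here are illustrative, not copied from the test file:

    // tail that convert_oai_to_prompt(..., /*add_ass=*/true) appends to the prompt
    std::string prompt_tail = "<|from|>assistant\n<|recipient|>";
    // raw model completion, shaped like the new test_response above
    std::string completion  = "get_car_price\n<|content|>{\"car_name\": \"Song\"}";
    // concatenated, the full assistant turn reads:
    //   <|from|>assistant
    //   <|recipient|>get_car_price
    //   <|content|>{"car_name": "Song"}
    std::string full_turn = prompt_tail + completion;
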
@@ -105,7 +101,7 @@ std::string test_response = R"(<|from|>assistant
 
 int main() {
     auto test_oai_input = json::parse(test_oai_input_json);
-    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input);
+    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
     std::cout << "\n" << prompt << "\n";
 
     std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";

@@ -9,6 +9,7 @@ using json = nlohmann::json;
 
 /**
  * Integration with functionary model: https://github.com/MeetKai/functionary
+ * Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
  *
  * A typical flow is:
  * - Step 1: user send request to model

@@ -201,7 +202,7 @@ inline std::string serialize_function(function_def & fn) {
 ///////////////////////////////////////////
 // Main hooks, to be called in oai.hpp
 
-inline std::string convert_oai_to_prompt(const json & body) {
+inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
     std::stringstream ss;
     // convert function definitions
     std::vector<json> tools = json_value(body, "tools", json::array());

@@ -224,30 +225,48 @@ inline std::string convert_oai_to_prompt(const json & body) {
     // convert history
     std::vector<json> messages = json_value(body, "messages", json::array());
     for (auto & msg_json : messages) {
+        // TODO: how to detect where to put "<|stop|>"?
         if (msg_json.count("tool_calls")) {
             // assistant request to function call, now re-passed to history
             std::vector<json> tool_calls = msg_json["tool_calls"];
-            for (auto & tc : tool_calls) {
+            for (size_t i = 0; i < tool_calls.size(); i++) {
+                auto & tc = tool_calls[i];
                 message msg;
                 msg.from = tc["function"]["name"];
                 msg.content = tc["function"]["arguments"];
+                msg.has_stop = i == tool_calls.size() - 1; // last msg
                 ss << msg.to_prompt();
             }
         } else {
             // all other types of message
             message msg(msg_json);
+            msg.has_stop = msg.from == "assistant"; // add stop if this is single text message from assistant (not contains tool_calls)
             ss << msg.to_prompt();
         }
     }
+    // add trailing assistant prompt
+    if (add_ass) {
+        ss << "<|from|>assistant\n<|recipient|>";
+        if (!allow_tool) {
+            ss << FUNCTIONARY_RECIP_NONE;
+        }
+    }
     return ss.str();
 }
 
-// be careful, the assistant output does not have "<|from|>assistant", you need to add it yourself!
 inline json convert_response_to_oai_choices(const std::string & content) {
+    std::string input_full = content;
     std::string text_response;
     json tool_calls = json::array();
     // parse all turns
-    std::vector<std::string> turns = str_split(content, "<|from|>");
+    std::vector<std::string> turns = str_split(input_full, "<|from|>");
+    if (!turns.empty()) {
+        // first turn may not have the assistant tag (because it was part of the prompt added by "add_ass"), we will put it back to parse the message
+        // the "<|from|>" will be added later
+        if (turns[0].find("<|recipient|>") == std::string::npos) {
+            turns[0] = "assistant\n<|recipient|>" + turns[0];
+        }
+    }
     for (auto & turn : turns) {
         std::string turn_full = "<|from|>" + turn;
         message msg(turn_full);

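To make the new parameters concrete, here is a hedged usage sketch. The request body is invented for illustration, and the exact output of message::to_prompt() is assumed to follow the <|from|>/<|recipient|>/<|content|> layout used in the test data above:

    json body = json::parse(R"({
        "messages": [
            {"role": "user", "content": "How much is a Song?"}
        ]
    })");
    // add_ass = true appends the trailing "<|from|>assistant\n<|recipient|>" header;
    // allow_tool = false would instead force the recipient to FUNCTIONARY_RECIP_NONE (text-only reply)
    std::string prompt = llama_functionary::convert_oai_to_prompt(body, /*add_ass=*/true, /*allow_tool=*/true);
    // Expected shape, roughly:
    //   <|from|>user
    //   <|recipient|>all
    //   <|content|>How much is a Song?
    //   <|from|>assistant
    //   <|recipient|>
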
@@ -35,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
     llama_params["prompt"] = enable_tool_calls
-        ? llama_functionary::convert_oai_to_prompt(body)
+        ? llama_functionary::convert_oai_to_prompt(body, true)
         : format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);

@@ -67,8 +67,10 @@ inline static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    llama_params["stop"].push_back(enable_tool_calls
+        ? "<|stop|>"   // functionary-specific: this model uses "<|stop|>" instead of "</s>"
+        : "<|im_end|>" // Ensure there is ChatML-specific end sequence among stop words
+    );
 
     return llama_params;
 }

@@ -104,7 +106,7 @@ inline static json format_final_response_oaicompat(
         : json::array({json{{"finish_reason", finish_reason},
                             {"index", 0},
                             {"message", json{{"content", content},
                                              {"role", "assistant"}}}}});
     }
 
     std::time_t t = std::time(0);

@@ -2773,15 +2773,16 @@ int main(int argc, char **argv)
         LOG_INFO("model loaded", {});
     }
 
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        // check if the template comes with the model is supported by us
-        llama.validate_model_chat_template(sparams);
-    }
-
     // Check tool_call ability
     sparams.enable_tool_calls = check_model_support_tool_calls(llama.model);
     if (sparams.enable_tool_calls) {
-        LOG_VERBOSE("Current model supports functionary tool_calls", {});
+        LOG_INFO("Current model supports functionary tool_calls", {});
+    }
+
+    // custom chat template is not supplied
+    if (sparams.chat_template.empty() && !sparams.enable_tool_calls) {
+        // check if the template comes with the model is supported by us
+        llama.validate_model_chat_template(sparams);
     }
 
     // Middleware for API key validation

@@ -219,6 +219,7 @@ inline bool check_model_support_tool_calls(const struct llama_model * model) {
     if (res < 0) {
         return false; // no template in model
     } else {
+        model_template.resize(res);
         std::string tmpl(model_template.data(), model_template.size());
         return tmpl.find("<|recipient|>") != std::string::npos;
     }

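The start of check_model_support_tool_calls is outside this hunk. For context, a plausible reconstruction of the whole helper, assuming the template is fetched with llama_model_meta_val_str() into a fixed-size buffer (the buffer size and metadata key below are assumptions, not taken from the diff):

    inline bool check_model_support_tool_calls(const struct llama_model * model) {
        std::vector<char> model_template(2048, 0); // assumed buffer size
        int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template", // assumed key
                                               model_template.data(), model_template.size());
        if (res < 0) {
            return false; // no template in model
        } else {
            model_template.resize(res); // trim the buffer to the reported template length
            std::string tmpl(model_template.data(), model_template.size());
            // functionary chat templates contain the <|recipient|> tag; ChatML and others do not
            return tmpl.find("<|recipient|>") != std::string::npos;
        }
    }

Without the added resize(res), tmpl would keep the unused tail of the buffer; the new line simply trims it to the actual template length before searching.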