first working version

parent  2a0d74d52e
commit  aea81772d6

5 changed files with 39 additions and 20 deletions
@@ -93,11 +93,7 @@ std::string test_oai_input_json = R"(
 )";
 
 
-std::string test_response = R"(<|from|>assistant
-<|recipient|>all
-<|content|>I will get the price of 2 cars and compare
-<|from|>assistant
-<|recipient|>get_car_price
+std::string test_response = R"(get_car_price
 <|content|>{"car_name": "Song"}
 <|from|>assistant
 <|recipient|>get_car_price
@@ -105,7 +101,7 @@ std::string test_response = R"(<|from|>assistant
 
 int main() {
     auto test_oai_input = json::parse(test_oai_input_json);
-    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input);
+    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
     std::cout << "\n" << prompt << "\n";
 
     std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";
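
For readers unfamiliar with the format exercised by this test: a functionary response is a sequence of turns tagged with <|from|>, <|recipient|> and <|content|>. The following standalone sketch shows how such a string breaks apart on the <|from|> tag; split_turns and the sample data are illustrative stand-ins for the project's str_split helper and test fixture, not the actual code.

// Sketch: split a functionary-style response into turns on the "<|from|>" tag.
// split_turns is an illustrative helper, not the project's str_split.
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string> split_turns(const std::string & text, const std::string & sep) {
    std::vector<std::string> parts;
    size_t start = 0;
    while (true) {
        size_t pos = text.find(sep, start);
        if (pos == std::string::npos) {
            parts.push_back(text.substr(start)); // tail after the last separator
            break;
        }
        parts.push_back(text.substr(start, pos - start));
        start = pos + sep.size();
    }
    return parts;
}

int main() {
    // note: the first turn has no leading "<|from|>assistant", exactly as the
    // comment in convert_response_to_oai_choices warns
    std::string response =
        "get_car_price\n<|content|>{\"car_name\": \"Song\"}\n"
        "<|from|>assistant\n<|recipient|>get_car_price\n<|content|>{\"car_name\": \"Tang\"}";
    for (const auto & turn : split_turns(response, "<|from|>")) {
        std::cout << "--- turn ---\n" << turn << "\n";
    }
    return 0;
}
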
@@ -9,6 +9,7 @@ using json = nlohmann::json;
 
 /**
  * Integration with functionary model: https://github.com/MeetKai/functionary
  * Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
  *
  * A typical flow is:
  * - Step 1: user send request to model
@@ -201,7 +202,7 @@ inline std::string serialize_function(function_def & fn) {
 ///////////////////////////////////////////
 // Main hooks, to be called in oai.hpp
 
-inline std::string convert_oai_to_prompt(const json & body) {
+inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
     std::stringstream ss;
     // convert function definitions
     std::vector<json> tools = json_value(body, "tools", json::array());
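
The new parameters change how the prompt ends: add_ass appends a trailing <|from|>assistant\n<|recipient|> header so generation continues as the assistant, and allow_tool = false pins the recipient so the model cannot pick a tool (see the next hunk). Below is a minimal sketch of the resulting prompt tail; FUNCTIONARY_RECIP_NONE is defined elsewhere in this header, so the "all\n<|content|>" placeholder is only an assumption.

// Sketch: the prompt tail produced for add_ass = true, with and without tools.
// RECIP_NONE is an assumed placeholder; the real constant is FUNCTIONARY_RECIP_NONE.
#include <iostream>
#include <sstream>
#include <string>

int main() {
    const std::string RECIP_NONE = "all\n<|content|>"; // assumption, see lead-in
    for (bool allow_tool : {true, false}) {
        std::stringstream ss;
        ss << "<|from|>assistant\n<|recipient|>"; // trailing assistant prompt (add_ass)
        if (!allow_tool) {
            ss << RECIP_NONE; // pin the recipient so the model cannot call a tool
        }
        std::cout << "allow_tool=" << allow_tool << ":\n" << ss.str() << "\n\n";
    }
    return 0;
}
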
@@ -224,30 +225,48 @@ inline std::string convert_oai_to_prompt(const json & body) {
     // convert history
     std::vector<json> messages = json_value(body, "messages", json::array());
     for (auto & msg_json : messages) {
-        // TODO: how to detect where to put "<|stop|>"?
         if (msg_json.count("tool_calls")) {
             // assistant request to function call, now re-passed to history
             std::vector<json> tool_calls = msg_json["tool_calls"];
-            for (auto & tc : tool_calls) {
+            for (size_t i = 0; i < tool_calls.size(); i++) {
+                auto & tc = tool_calls[i];
                 message msg;
                 msg.from = tc["function"]["name"];
                 msg.content = tc["function"]["arguments"];
+                msg.has_stop = i == tool_calls.size() - 1; // last msg
                 ss << msg.to_prompt();
             }
         } else {
             // all other types of message
             message msg(msg_json);
+            msg.has_stop = msg.from == "assistant"; // add stop if this is single text message from assistant (not contains tool_calls)
             ss << msg.to_prompt();
         }
     }
+    // add trailing assistant prompt
+    if (add_ass) {
+        ss << "<|from|>assistant\n<|recipient|>";
+        if (!allow_tool) {
+            ss << FUNCTIONARY_RECIP_NONE;
+        }
+    }
     return ss.str();
 }
 
+// be careful, the assistant output does not have "<|from|>assistant", you need to add it yourself!
 inline json convert_response_to_oai_choices(const std::string & content) {
+    std::string input_full = content;
     std::string text_response;
     json tool_calls = json::array();
     // parse all turns
-    std::vector<std::string> turns = str_split(content, "<|from|>");
+    std::vector<std::string> turns = str_split(input_full, "<|from|>");
+    if (!turns.empty()) {
+        // first turn may not have the assistant tag (because it was part of the prompt added by "add_ass"), we will put it back to parse the message
+        // the "<|from|>" will be added later
+        if (turns[0].find("<|recipient|>") == std::string::npos) {
+            turns[0] = "assistant\n<|recipient|>" + turns[0];
+        }
+    }
     for (auto & turn : turns) {
         std::string turn_full = "<|from|>" + turn;
         message msg(turn_full);
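
The key change in this hunk is where <|stop|> goes: the last tool call of an assistant turn gets it, and so does a plain assistant text message. Below is a self-contained sketch of that rule; demo_msg, the tag layout in to_prompt(), and the sample data are simplified stand-ins for the header's message struct, not the real implementation.

// Sketch of the "<|stop|>" placement rule: only the last tool call of a turn,
// and any plain assistant text message, carry the stop tag.
#include <iostream>
#include <string>
#include <vector>

struct demo_msg {                 // simplified stand-in for the header's message struct
    std::string from;
    std::string recipient;
    std::string content;
    bool has_stop;

    std::string to_prompt() const {
        std::string out = "<|from|>" + from + "\n<|recipient|>" + recipient +
                          "\n<|content|>" + content + "\n";
        if (has_stop) {
            out += "<|stop|>\n";
        }
        return out;
    }
};

int main() {
    // one assistant turn calling two tools, then a plain text reply (illustrative data)
    std::vector<demo_msg> tool_calls = {
        {"assistant", "get_car_price", "{\"car_name\": \"Song\"}", false},
        {"assistant", "get_car_price", "{\"car_name\": \"Tang\"}", false},
    };
    std::string prompt;
    for (size_t i = 0; i < tool_calls.size(); i++) {
        tool_calls[i].has_stop = (i == tool_calls.size() - 1); // last msg only
        prompt += tool_calls[i].to_prompt();
    }
    demo_msg text_reply = {"assistant", "all", "I will get the price of 2 cars and compare", true};
    prompt += text_reply.to_prompt();
    std::cout << prompt;
    return 0;
}
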
@@ -35,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
     llama_params["prompt"] = enable_tool_calls
-        ? llama_functionary::convert_oai_to_prompt(body)
+        ? llama_functionary::convert_oai_to_prompt(body, true)
         : format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
@@ -67,8 +67,10 @@ inline static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    llama_params["stop"].push_back(enable_tool_calls
+        ? "<|stop|>" // functionary-specific: this model uses "<|stop|>" instead of "</s>"
+        : "<|im_end|>" // Ensure there is ChatML-specific end sequence among stop words
+    );
 
     return llama_params;
 }
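
Stop-word selection in one place: functionary models end a turn with <|stop|> rather than the ChatML <|im_end|>, so the server now pushes one or the other instead of always appending <|im_end|>. A tiny sketch of that choice, using a plain std::vector in place of the json array the server actually builds.

// Sketch: pick the stop word depending on tool-call support.
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string> build_stop_words(bool enable_tool_calls) {
    std::vector<std::string> stop;
    stop.push_back(enable_tool_calls
        ? "<|stop|>"    // functionary uses "<|stop|>" instead of "</s>"
        : "<|im_end|>"  // ChatML end sequence
    );
    return stop;
}

int main() {
    for (bool tools : {false, true}) {
        std::cout << "enable_tool_calls=" << tools << " -> " << build_stop_words(tools)[0] << "\n";
    }
    return 0;
}
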
@@ -104,7 +106,7 @@ inline static json format_final_response_oaicompat(
         : json::array({json{{"finish_reason", finish_reason},
                             {"index", 0},
                             {"message", json{{"content", content},
-                                             {"role", "assistant"}}}}});
+                                             {"role", "assistant"}}}}});
     }
 
     std::time_t t = std::time(0);
@@ -2773,15 +2773,16 @@ int main(int argc, char **argv)
         LOG_INFO("model loaded", {});
     }
 
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        // check if the template comes with the model is supported by us
-        llama.validate_model_chat_template(sparams);
-    }
-
     // Check tool_call ability
     sparams.enable_tool_calls = check_model_support_tool_calls(llama.model);
     if (sparams.enable_tool_calls) {
-        LOG_VERBOSE("Current model supports functionary tool_calls", {});
+        LOG_INFO("Current model supports functionary tool_calls", {});
    }
 
+    // custom chat template is not supplied
+    if (sparams.chat_template.empty() && !sparams.enable_tool_calls) {
+        // check if the template comes with the model is supported by us
+        llama.validate_model_chat_template(sparams);
+    }
+
     // Middleware for API key validation
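
The startup sequence is reordered here: tool-call detection runs first, and the built-in chat-template validation is skipped when the model turns out to be a functionary model. A sketch of that gating, with plain booleans standing in for sparams and the real detection/validation calls.

// Sketch of the reordered startup checks; booleans stand in for sparams and
// the real detection/validation calls.
#include <iostream>

int main() {
    bool chat_template_empty       = true; // stand-in for sparams.chat_template.empty()
    bool model_supports_tool_calls = true; // stand-in for check_model_support_tool_calls(llama.model)

    if (model_supports_tool_calls) {
        std::cout << "Current model supports functionary tool_calls\n";
    }
    // built-in template validation is now skipped for functionary models
    if (chat_template_empty && !model_supports_tool_calls) {
        std::cout << "would call llama.validate_model_chat_template(sparams)\n";
    } else {
        std::cout << "chat template validation skipped\n";
    }
    return 0;
}
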
@@ -219,6 +219,7 @@ inline bool check_model_support_tool_calls(const struct llama_model * model) {
     if (res < 0) {
         return false; // no template in model
     } else {
         model_template.resize(res);
+        std::string tmpl(model_template.data(), model_template.size());
         return tmpl.find("<|recipient|>") != std::string::npos;
     }
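
Detection itself is a heuristic: a model is treated as functionary-capable when its embedded chat template mentions the <|recipient|> tag. A minimal sketch of that check; the template strings below are trimmed, illustrative stand-ins for what the real function reads from the model's metadata.

// Sketch of the detection heuristic: look for "<|recipient|>" in the chat template.
#include <iostream>
#include <string>

static bool supports_tool_calls(const std::string & chat_template) {
    return chat_template.find("<|recipient|>") != std::string::npos;
}

int main() {
    // trimmed, illustrative templates; real ones come from the model's metadata
    std::string functionary_tmpl = "{{ '<|from|>' + role + '\\n<|recipient|>' + recipient }}";
    std::string chatml_tmpl      = "{{ '<|im_start|>' + role + '\\n' + content + '<|im_end|>' }}";
    std::cout << supports_tool_calls(functionary_tmpl) << "\n"; // 1
    std::cout << supports_tool_calls(chatml_tmpl)      << "\n"; // 0
    return 0;
}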