first working version

ngxson 2024-02-24 12:29:06 +01:00
parent 2a0d74d52e
commit aea81772d6
5 changed files with 39 additions and 20 deletions

View file

@@ -93,11 +93,7 @@ std::string test_oai_input_json = R"(
)";
std::string test_response = R"(<|from|>assistant
<|recipient|>all
<|content|>I will get the price of 2 cars and compare
<|from|>assistant
<|recipient|>get_car_price
std::string test_response = R"(get_car_price
<|content|>{"car_name": "Song"}
<|from|>assistant
<|recipient|>get_car_price
@@ -105,7 +101,7 @@ std::string test_response = R"(<|from|>assistant
int main() {
auto test_oai_input = json::parse(test_oai_input_json);
auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input);
auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
std::cout << "\n" << prompt << "\n";
std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";

View file

@@ -9,6 +9,7 @@ using json = nlohmann::json;
/**
* Integration with functionary model: https://github.com/MeetKai/functionary
* Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
*
* A typical flow is:
* - Step 1: user sends a request to the model
@@ -201,7 +202,7 @@ inline std::string serialize_function(function_def & fn) {
///////////////////////////////////////////
// Main hooks, to be called in oai.hpp
inline std::string convert_oai_to_prompt(const json & body) {
inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
std::stringstream ss;
// convert function definitions
std::vector<json> tools = json_value(body, "tools", json::array());
@@ -224,30 +225,48 @@ inline std::string convert_oai_to_prompt(const json & body) {
// convert history
std::vector<json> messages = json_value(body, "messages", json::array());
for (auto & msg_json : messages) {
// TODO: how to detect where to put "<|stop|>"?
if (msg_json.count("tool_calls")) {
// assistant requested a function call; it is now re-passed into the history
std::vector<json> tool_calls = msg_json["tool_calls"];
for (auto & tc : tool_calls) {
for (size_t i = 0; i < tool_calls.size(); i++) {
auto & tc = tool_calls[i];
message msg;
msg.from = tc["function"]["name"];
msg.content = tc["function"]["arguments"];
msg.has_stop = i == tool_calls.size() - 1; // last msg
ss << msg.to_prompt();
}
} else {
// all other types of message
message msg(msg_json);
msg.has_stop = msg.from == "assistant"; // add stop if this is a plain text message from the assistant (no tool_calls)
ss << msg.to_prompt();
}
}
// add trailing assistant prompt
if (add_ass) {
ss << "<|from|>assistant\n<|recipient|>";
if (!allow_tool) {
ss << FUNCTIONARY_RECIP_NONE;
}
}
return ss.str();
}
// be careful: the assistant output does not include "<|from|>assistant", you need to add it yourself!
inline json convert_response_to_oai_choices(const std::string & content) {
std::string input_full = content;
std::string text_response;
json tool_calls = json::array();
// parse all turns
std::vector<std::string> turns = str_split(content, "<|from|>");
std::vector<std::string> turns = str_split(input_full, "<|from|>");
if (!turns.empty()) {
// the first turn may not have the assistant tag (it was part of the prompt added by "add_ass"), so we put it back before parsing the message
// the "<|from|>" prefix itself will be re-added below
if (turns[0].find("<|recipient|>") == std::string::npos) {
turns[0] = "assistant\n<|recipient|>" + turns[0];
}
}
for (auto & turn : turns) {
std::string turn_full = "<|from|>" + turn;
message msg(turn_full);
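Taken together, the new add_ass/allow_tool parameters and the parser change form one contract: the prompt ends on an open "<|from|>assistant\n<|recipient|>" turn, so the model's raw output carries no leading assistant tag and convert_response_to_oai_choices() re-inserts it before splitting on "<|from|>". A hedged usage sketch under that assumption (the header name and the value of FUNCTIONARY_RECIP_NONE, presumably the plain-text "all" recipient, are not confirmed by this hunk):

#include <string>
#include "functionary.hpp" // assumed location of llama_functionary

// hedged sketch: one tool-call round trip through the helpers above
inline nlohmann::json run_functionary_turn(const nlohmann::json & oai_body,
                                           const std::string & raw_model_output) {
    // end the prompt on an open assistant turn
    // (pass a third argument allow_tool = false to pin the recipient to FUNCTIONARY_RECIP_NONE
    //  and forbid tool calls for this turn)
    std::string prompt = llama_functionary::convert_oai_to_prompt(oai_body, /* add_ass */ true);
    (void) prompt; // the prompt would be sent to the model here

    // the raw output starts at the recipient (no "<|from|>assistant" prefix);
    // the parser puts the prefix back and returns OAI-style "choices"
    return llama_functionary::convert_response_to_oai_choices(raw_model_output);
}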

View file

@@ -35,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = enable_tool_calls
? llama_functionary::convert_oai_to_prompt(body)
? llama_functionary::convert_oai_to_prompt(body, true)
: format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
@@ -67,8 +67,10 @@ inline static json oaicompat_completion_params_parse(
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Ensure there is a ChatML-specific end sequence among stop words
llama_params["stop"].push_back("<|im_end|>");
llama_params["stop"].push_back(enable_tool_calls
? "<|stop|>" // functionary-specific: this model uses "<|stop|>" instead of "</s>"
: "<|im_end|>" // Ensure there is a ChatML-specific end sequence among stop words
);
return llama_params;
}
@@ -104,7 +106,7 @@ inline static json format_final_response_oaicompat(
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
{"role", "assistant"}}}}});
}
std::time_t t = std::time(0);
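For reference, a minimal OpenAI-style request body that would take the functionary branch above (shape inferred from the test file earlier in this commit; the "tools" and "messages" keys are the ones read by convert_oai_to_prompt, the field contents are illustrative, and the rest follows the usual OpenAI tool-calling schema):

auto body = nlohmann::json::parse(R"({
    "model": "functionary",
    "messages": [
        { "role": "user", "content": "How much does the car Song cost?" }
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_car_price",
            "description": "Get the price of a car by its name",
            "parameters": {
                "type": "object",
                "properties": { "car_name": { "type": "string" } },
                "required": ["car_name"]
            }
        }
    }]
})");
// with enable_tool_calls set, this body is rendered by
// llama_functionary::convert_oai_to_prompt(body, true) instead of format_chat()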

View file

@@ -2773,15 +2773,16 @@ int main(int argc, char **argv)
LOG_INFO("model loaded", {});
}
if (sparams.chat_template.empty()) { // custom chat template is not supplied
// check if the template that comes with the model is supported by us
llama.validate_model_chat_template(sparams);
}
// Check tool_call ability
sparams.enable_tool_calls = check_model_support_tool_calls(llama.model);
if (sparams.enable_tool_calls) {
LOG_VERBOSE("Current model supports functionary tool_calls", {});
LOG_INFO("Current model supports functionary tool_calls", {});
}
// custom chat template is not supplied
if (sparams.chat_template.empty() && !sparams.enable_tool_calls) {
// check if the template that comes with the model is supported by us
llama.validate_model_chat_template(sparams);
}
// Middleware for API key validation

View file

@@ -219,6 +219,7 @@ inline bool check_model_support_tool_calls(const struct llama_model * model) {
if (res < 0) {
return false; // no template in model
} else {
model_template.resize(res);
std::string tmpl(model_template.data(), model_template.size());
return tmpl.find("<|recipient|>") != std::string::npos;
}
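The added resize() trims the buffer to the length reported by the metadata lookup, so the find() runs on the actual template text rather than a zero-padded buffer. A sketch of the whole check, under the assumption that res comes from the standard llama_model_meta_val_str() call on the conventional "tokenizer.chat_template" key (neither is visible in this hunk, and the buffer size is illustrative):

#include <vector>
#include <string>
#include "llama.h"

// hedged sketch: detect a functionary-style chat template in the model's GGUF metadata
inline bool check_model_support_tool_calls(const struct llama_model * model) {
    std::vector<char> model_template(2048, 0);
    // returns the value length, or a negative value when the key is absent
    int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template",
                                           model_template.data(), model_template.size());
    if (res < 0) {
        return false; // no template in model
    }
    model_template.resize(res); // drop the unused tail of the buffer
    std::string tmpl(model_template.data(), model_template.size());
    return tmpl.find("<|recipient|>") != std::string::npos; // functionary templates use <|recipient|>
}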