minor fix to address tools_call output format

This commit is contained in:
Yingbei 2024-03-22 15:41:44 -07:00
parent 3cf8de4939
commit 79fd89a62b
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 6 additions and 2 deletions

View file

@ -341,7 +341,9 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
std::string final_str = "You have access to the following tools:\n"; std::string final_str = "You have access to the following tools:\n";
json type_mapping = { json type_mapping = {
{"string", "str"}, {"string", "str"},
{"integer", "int"},
{"number", "float"}, {"number", "float"},
{"float", "float"},
{"object", "Dict[str, Any]"}, {"object", "Dict[str, Any]"},
{"array", "List"}, {"array", "List"},
{"boolean", "bool"}, {"boolean", "bool"},
@ -592,10 +594,11 @@ static json format_final_response_oaicompat(const json & request, json result, c
json tool_call; json tool_call;
tool_call["id"] = pc["id"]; tool_call["id"] = pc["id"];
tool_call["type"] = "function"; tool_call["type"] = "function";
tool_call["function"] = json{{ tool_call["function"] = json{
{"name" , pc["name"]}, {"name" , pc["name"]},
{"arguments" , pc["kwargs"].dump()}, {"arguments" , pc["kwargs"].dump()},
}}; };
printf("format_final_response_oaicompat: tool_call: %s\n", tool_call.dump().c_str());
oai_format_tool_calls.push_back(tool_call); oai_format_tool_calls.push_back(tool_call);
} }
choices = json::array({json{{"finish_reason", finish_reason}, choices = json::array({json{{"finish_reason", finish_reason},

View file

@ -14518,6 +14518,7 @@ static int32_t llama_chat_apply_template_internal(
// construct the prompt // construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning bool is_inside_turn = true; // skip BOS at the beginning
// ss << "[INST] "; // ss << "[INST] ";
for (auto message : chat) { for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content; std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role); std::string role(message->role);