Using ordered_map instead to make sure the function call observations are in the correct order.

This commit is contained in:
Yingbei 2024-04-01 18:20:49 -07:00
parent f2ba723bd9
commit 1eafdc95c8
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 3 additions and 4 deletions

View file

@ -103,7 +103,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
std::vector<json> calls; std::vector<json> calls;
std::string delimiter = "<<functions>>"; std::string delimiter = "<<functions>>";
std::string source_code; std::string source_code;
printf("Parsing source_string: %s\n", source_string.c_str()); printf("Parsing source_string::%s\n", source_string.c_str());
size_t startPos = source_string.find(delimiter); size_t startPos = source_string.find(delimiter);
if (startPos != std::string::npos) { if (startPos != std::string::npos) {
source_code = source_string.substr(startPos + delimiter.length()); source_code = source_string.substr(startPos + delimiter.length());

View file

@ -10,7 +10,6 @@
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <random> #include <random>
#include <unordered_map>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@ -421,7 +420,7 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
for (const auto& def : function_definitions) { for (const auto& def : function_definitions) {
final_str += def + "\n"; final_str += def + "\n";
} }
final_str += "Use the following format if using a tool:\n[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]"; final_str += "Use the following format if using tools:\n<<functions>>[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]";
return final_str; return final_str;
} }
@ -507,7 +506,7 @@ static json oaicompat_completion_params_parse(
// temp_vec.push_back(function_call); // temp_vec.push_back(function_call);
// } // }
std::vector<json> temp_vec; std::vector<json> temp_vec;
std::unordered_map<std::string, std::string> func_observation_map; nlohmann::ordered_map<std::string, std::string> func_observation_map;
for (size_t i = 0; i < body["messages"].size(); ++i) { for (size_t i = 0; i < body["messages"].size(); ++i) {
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) { if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {