some clean up

This commit is contained in:
Yingbei 2024-03-25 18:11:21 -07:00
parent 784fa90cbe
commit eaec0b8748
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 21 additions and 20 deletions

View file

@ -11,20 +11,31 @@ using json = nlohmann::json; // Use an alias for easier access
// Convert the textual form of a single argument value into a JSON value.
//
// Recognised forms, checked in this order:
//   1. Booleans: "true"/"True" and "false"/"False".
//   2. Quoted strings: content wrapped in matching single or double quotes;
//      the surrounding quotes are stripped and the inner text is returned.
//   3. Integers: the whole string must be consumed by std::stoll.
//   4. Floating point: the whole string must be consumed by std::stod.
// Anything else is returned unchanged as a JSON string.
//
// Note: std::stoll/std::stod skip leading whitespace, so " 12" parses as 12.
static json parseValue(const std::string& content) {
    // Check for boolean
    if (content == "true" || content == "True") {
        return true;
    }
    if (content == "false" || content == "False") {
        return false;
    }

    // Check for quoted string (opening and closing quote must match)
    if (content.size() >= 2) {
        const char quote = content.front();
        if ((quote == '"' || quote == '\'') && content.back() == quote) {
            return content.substr(1, content.size() - 2);
        }
    }

    // Attempt to parse as an integer first. A 64-bit parse is used so that
    // values beyond the range of int are still returned as integers.
    size_t processed = 0;
    try {
        const long long asInteger = std::stoll(content, &processed);
        if (processed == content.size()) {
            return asInteger;
        }
        // Only a prefix was an integer (e.g. "1.5", "1e3"): try floating point below.
    } catch (const std::invalid_argument&) {
        // No numeric prefix at all. Returning the raw text here (instead of
        // trying std::stod) keeps tokens such as "inf" and "nan" as strings.
        return content;
    } catch (const std::out_of_range&) {
        // Too large for an integer; it may still be representable as a double,
        // so fall through rather than giving up on the number.
    }

    // Then try floating point
    try {
        const double asDouble = std::stod(content, &processed);
        if (processed == content.size()) {
            return asDouble;
        }
    } catch (const std::invalid_argument&) {
        // Not a number, ignore
    } catch (const std::out_of_range&) {
        // Number out of range, ignore
    }

    // TODO: for array, dict, object, function, should further add logic to parse them recursively.
    return content;
}
@ -34,7 +45,7 @@ static json parseValue(const std::string& content) {
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) { static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
auto type = ts_node_type(node); auto type = ts_node_type(node);
printf("type: %s\n", type); // printf("type: %s\n", type);
// Only interested in call_expression nodes at the outermost level // Only interested in call_expression nodes at the outermost level
if (strcmp(type, "call") == 0) { if (strcmp(type, "call") == 0) {
@ -92,12 +103,11 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
std::vector<json> calls; std::vector<json> calls;
std::string delimiter = "<<functions>>"; std::string delimiter = "<<functions>>";
std::string source_code; std::string source_code;
printf("source_string: %s\n", source_string.c_str()); printf("Parsing source_string: %s\n", source_string.c_str());
size_t startPos = source_string.find(delimiter); size_t startPos = source_string.find(delimiter);
if (startPos != std::string::npos) { if (startPos != std::string::npos) {
source_code = source_string.substr(startPos + delimiter.length()); source_code = source_string.substr(startPos + delimiter.length());
} else { } else {
printf("no functions\n");
return calls; return calls;
} }
TSParser *parser = ts_parser_new(); TSParser *parser = ts_parser_new();

View file

@ -509,9 +509,6 @@ static json oaicompat_completion_params_parse(
std::vector<json> temp_vec; std::vector<json> temp_vec;
std::unordered_map<std::string, std::string> func_observation_map; std::unordered_map<std::string, std::string> func_observation_map;
for (size_t i = 0; i < body["messages"].size(); ++i) { for (size_t i = 0; i < body["messages"].size(); ++i) {
printf("body[\"messages\"][%d][\"role\"] = %s\n", i, body["messages"][i]["role"].get<std::string>().c_str());
printf("Message: %s\n", body["messages"][i].dump().c_str());
printf("%d\n", body["messages"][i].contains("tool_calls"));
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) { if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
// insert the observation from the tool call before the next message // insert the observation from the tool call before the next message
@ -548,12 +545,9 @@ static json oaicompat_completion_params_parse(
} }
// else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){ // else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){ else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
printf("Tool call detected\n");
// convert OpenAI function call format to Rubra format // convert OpenAI function call format to Rubra format
std::string tool_call_str = ""; std::string tool_call_str = "";
printf("Tool calls: %s\n", body["messages"][i]["tool_calls"].dump().c_str());
for (const auto & tool_call : body["messages"][i]["tool_calls"]) { for (const auto & tool_call : body["messages"][i]["tool_calls"]) {
printf("Tool call id: %s\n", tool_call["id"].get<std::string>().c_str());
std::string func_str = ""; std::string func_str = "";
func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message
json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch the exceptions json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch the exceptions
@ -570,7 +564,6 @@ static json oaicompat_completion_params_parse(
tool_call_str += func_str; tool_call_str += func_str;
} }
tool_call_str = std::string("<<functions>>") + "[" + tool_call_str + "]"; tool_call_str = std::string("<<functions>>") + "[" + tool_call_str + "]";
printf("Tool call string: %s\n", tool_call_str.c_str());
json function_call; json function_call;
function_call["role"] = "function"; function_call["role"] = "function";
@ -578,8 +571,6 @@ static json oaicompat_completion_params_parse(
temp_vec.push_back(function_call); temp_vec.push_back(function_call);
} }
else if (body["messages"][i]["role"] == "tool") { else if (body["messages"][i]["role"] == "tool") {
printf("Observation detected\n");
printf(body["messages"][i].dump().c_str());
std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>(); std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>();
if (func_observation_map.find(tool_call_id) != func_observation_map.end()) { if (func_observation_map.find(tool_call_id) != func_observation_map.end()) {
func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>(); func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>();