From eaec0b8748c0e26be01b48e11aeb50ff393937fc Mon Sep 17 00:00:00 2001 From: Yingbei Date: Mon, 25 Mar 2024 18:11:21 -0700 Subject: [PATCH] some clean up --- examples/server/python-parser.hpp | 32 ++++++++++++++++++++----------- examples/server/utils.hpp | 9 --------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/examples/server/python-parser.hpp b/examples/server/python-parser.hpp index 5037e381d..4f35e6b9e 100644 --- a/examples/server/python-parser.hpp +++ b/examples/server/python-parser.hpp @@ -11,20 +11,31 @@ using json = nlohmann::json; // Use an alias for easier access static json parseValue(const std::string& content) { - // Check for numerical value - if (!content.empty() && std::all_of(content.begin(), content.end(), ::isdigit)) { - return std::stoi(content); - } // Check for boolean - if (content == "True" || content == "true") { + if (content == "true" || content == "True") { return true; - } else if (content == "False" || content == "false") { + } else if (content == "false" || content == "False") { return false; } - if ((content.size() >= 2 && content.front() == '"' && content.back() == '"') || - (content.size() >= 2 && content.front() == '\'' && content.back() == '\'')) { + // Check for quoted string + if ((content.size() >= 2 && (content.front() == '"' && content.back() == '"')) || + (content.size() >= 2 && (content.front() == '\'' && content.back() == '\''))) { return content.substr(1, content.size() - 2); } + // Attempt to parse as number (int or float) + try { + size_t processed; + // Try integer first + int i = std::stoi(content, &processed); + if (processed == content.size()) return i; + // Then try floating point + double d = std::stod(content, &processed); + if (processed == content.size()) return d; + } catch (const std::invalid_argument& e) { + // Not a number, ignore + } catch (const std::out_of_range& e) { + // Number out of range, ignore + } // TODO: for array, dict, object, function, should further add logic to parse them recursively. return content; } @@ -34,7 +45,7 @@ static json parseValue(const std::string& content) { static void parseFunctionCalls(const TSNode& node, std::vector& calls, const char* source_code, uint32_t indent = 0) { auto type = ts_node_type(node); - printf("type: %s\n", type); + // printf("type: %s\n", type); // Only interested in call_expression nodes at the outermost level if (strcmp(type, "call") == 0) { @@ -92,12 +103,11 @@ static std::vector parsePythonFunctionCalls(std::string source_string) { std::vector calls; std::string delimiter = "<>"; std::string source_code; - printf("source_string: %s\n", source_string.c_str()); + printf("Parsing source_string: %s\n", source_string.c_str()); size_t startPos = source_string.find(delimiter); if (startPos != std::string::npos) { source_code = source_string.substr(startPos + delimiter.length()); } else { - printf("no functions\n"); return calls; } TSParser *parser = ts_parser_new(); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index d94ffe32d..cd9172bcd 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -509,9 +509,6 @@ static json oaicompat_completion_params_parse( std::vector temp_vec; std::unordered_map func_observation_map; for (size_t i = 0; i < body["messages"].size(); ++i) { - printf("body[\"messages\"][%d][\"role\"] = %s\n", i, body["messages"][i]["role"].get().c_str()); - printf("Message: %s\n", body["messages"][i].dump().c_str()); - printf("%d\n", body["messages"][i].contains("tool_calls")); if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) { // insert the observation from the tool call before the next message @@ -548,12 +545,9 @@ static json oaicompat_completion_params_parse( } // else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){ else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){ - printf("Tool call detected\n"); // convert OpenAI function call format to Rubra format std::string tool_call_str = ""; - printf("Tool calls: %s\n", body["messages"][i]["tool_calls"].dump().c_str()); for (const auto & tool_call : body["messages"][i]["tool_calls"]) { - printf("Tool call id: %s\n", tool_call["id"].get().c_str()); std::string func_str = ""; func_observation_map[tool_call["id"].get()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message json args = json::parse(tool_call["function"]["arguments"].get()); // TODO: catch the exceptions @@ -570,7 +564,6 @@ static json oaicompat_completion_params_parse( tool_call_str += func_str; } tool_call_str = std::string("<>") + "[" + tool_call_str + "]"; - printf("Tool call string: %s\n", tool_call_str.c_str()); json function_call; function_call["role"] = "function"; @@ -578,8 +571,6 @@ static json oaicompat_completion_params_parse( temp_vec.push_back(function_call); } else if (body["messages"][i]["role"] == "tool") { - printf("Observation detected\n"); - printf(body["messages"][i].dump().c_str()); std::string tool_call_id = body["messages"][i]["tool_call_id"].get(); if (func_observation_map.find(tool_call_id) != func_observation_map.end()) { func_observation_map[tool_call_id] = body["messages"][i]["content"].get();