some clean up
This commit is contained in:
parent
784fa90cbe
commit
eaec0b8748
2 changed files with 21 additions and 20 deletions
|
@ -11,20 +11,31 @@ using json = nlohmann::json; // Use an alias for easier access
|
|||
|
||||
|
||||
static json parseValue(const std::string& content) {
|
||||
// Check for numerical value
|
||||
if (!content.empty() && std::all_of(content.begin(), content.end(), ::isdigit)) {
|
||||
return std::stoi(content);
|
||||
}
|
||||
// Check for boolean
|
||||
if (content == "True" || content == "true") {
|
||||
if (content == "true" || content == "True") {
|
||||
return true;
|
||||
} else if (content == "False" || content == "false") {
|
||||
} else if (content == "false" || content == "False") {
|
||||
return false;
|
||||
}
|
||||
if ((content.size() >= 2 && content.front() == '"' && content.back() == '"') ||
|
||||
(content.size() >= 2 && content.front() == '\'' && content.back() == '\'')) {
|
||||
// Check for quoted string
|
||||
if ((content.size() >= 2 && (content.front() == '"' && content.back() == '"')) ||
|
||||
(content.size() >= 2 && (content.front() == '\'' && content.back() == '\''))) {
|
||||
return content.substr(1, content.size() - 2);
|
||||
}
|
||||
// Attempt to parse as number (int or float)
|
||||
try {
|
||||
size_t processed;
|
||||
// Try integer first
|
||||
int i = std::stoi(content, &processed);
|
||||
if (processed == content.size()) return i;
|
||||
// Then try floating point
|
||||
double d = std::stod(content, &processed);
|
||||
if (processed == content.size()) return d;
|
||||
} catch (const std::invalid_argument& e) {
|
||||
// Not a number, ignore
|
||||
} catch (const std::out_of_range& e) {
|
||||
// Number out of range, ignore
|
||||
}
|
||||
// TODO: for array, dict, object, function, should further add logic to parse them recursively.
|
||||
return content;
|
||||
}
|
||||
|
@ -34,7 +45,7 @@ static json parseValue(const std::string& content) {
|
|||
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
|
||||
auto type = ts_node_type(node);
|
||||
|
||||
printf("type: %s\n", type);
|
||||
// printf("type: %s\n", type);
|
||||
// Only interested in call_expression nodes at the outermost level
|
||||
if (strcmp(type, "call") == 0) {
|
||||
|
||||
|
@ -92,12 +103,11 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
|||
std::vector<json> calls;
|
||||
std::string delimiter = "<<functions>>";
|
||||
std::string source_code;
|
||||
printf("source_string: %s\n", source_string.c_str());
|
||||
printf("Parsing source_string: %s\n", source_string.c_str());
|
||||
size_t startPos = source_string.find(delimiter);
|
||||
if (startPos != std::string::npos) {
|
||||
source_code = source_string.substr(startPos + delimiter.length());
|
||||
} else {
|
||||
printf("no functions\n");
|
||||
return calls;
|
||||
}
|
||||
TSParser *parser = ts_parser_new();
|
||||
|
|
|
@ -509,9 +509,6 @@ static json oaicompat_completion_params_parse(
|
|||
std::vector<json> temp_vec;
|
||||
std::unordered_map<std::string, std::string> func_observation_map;
|
||||
for (size_t i = 0; i < body["messages"].size(); ++i) {
|
||||
printf("body[\"messages\"][%d][\"role\"] = %s\n", i, body["messages"][i]["role"].get<std::string>().c_str());
|
||||
printf("Message: %s\n", body["messages"][i].dump().c_str());
|
||||
printf("%d\n", body["messages"][i].contains("tool_calls"));
|
||||
|
||||
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
|
||||
// insert the observation from the tool call before the next message
|
||||
|
@ -548,12 +545,9 @@ static json oaicompat_completion_params_parse(
|
|||
}
|
||||
// else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
|
||||
else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
|
||||
printf("Tool call detected\n");
|
||||
// convert OpenAI function call format to Rubra format
|
||||
std::string tool_call_str = "";
|
||||
printf("Tool calls: %s\n", body["messages"][i]["tool_calls"].dump().c_str());
|
||||
for (const auto & tool_call : body["messages"][i]["tool_calls"]) {
|
||||
printf("Tool call id: %s\n", tool_call["id"].get<std::string>().c_str());
|
||||
std::string func_str = "";
|
||||
func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message
|
||||
json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch the exceptions
|
||||
|
@ -570,7 +564,6 @@ static json oaicompat_completion_params_parse(
|
|||
tool_call_str += func_str;
|
||||
}
|
||||
tool_call_str = std::string("<<functions>>") + "[" + tool_call_str + "]";
|
||||
printf("Tool call string: %s\n", tool_call_str.c_str());
|
||||
|
||||
json function_call;
|
||||
function_call["role"] = "function";
|
||||
|
@ -578,8 +571,6 @@ static json oaicompat_completion_params_parse(
|
|||
temp_vec.push_back(function_call);
|
||||
}
|
||||
else if (body["messages"][i]["role"] == "tool") {
|
||||
printf("Observation detected\n");
|
||||
printf(body["messages"][i].dump().c_str());
|
||||
std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>();
|
||||
if (func_observation_map.find(tool_call_id) != func_observation_map.end()) {
|
||||
func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue