some clean up
This commit is contained in:
parent
784fa90cbe
commit
eaec0b8748
2 changed files with 21 additions and 20 deletions
|
@ -11,20 +11,31 @@ using json = nlohmann::json; // Use an alias for easier access
|
||||||
|
|
||||||
|
|
||||||
static json parseValue(const std::string& content) {
|
static json parseValue(const std::string& content) {
|
||||||
// Check for numerical value
|
|
||||||
if (!content.empty() && std::all_of(content.begin(), content.end(), ::isdigit)) {
|
|
||||||
return std::stoi(content);
|
|
||||||
}
|
|
||||||
// Check for boolean
|
// Check for boolean
|
||||||
if (content == "True" || content == "true") {
|
if (content == "true" || content == "True") {
|
||||||
return true;
|
return true;
|
||||||
} else if (content == "False" || content == "false") {
|
} else if (content == "false" || content == "False") {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if ((content.size() >= 2 && content.front() == '"' && content.back() == '"') ||
|
// Check for quoted string
|
||||||
(content.size() >= 2 && content.front() == '\'' && content.back() == '\'')) {
|
if ((content.size() >= 2 && (content.front() == '"' && content.back() == '"')) ||
|
||||||
|
(content.size() >= 2 && (content.front() == '\'' && content.back() == '\''))) {
|
||||||
return content.substr(1, content.size() - 2);
|
return content.substr(1, content.size() - 2);
|
||||||
}
|
}
|
||||||
|
// Attempt to parse as number (int or float)
|
||||||
|
try {
|
||||||
|
size_t processed;
|
||||||
|
// Try integer first
|
||||||
|
int i = std::stoi(content, &processed);
|
||||||
|
if (processed == content.size()) return i;
|
||||||
|
// Then try floating point
|
||||||
|
double d = std::stod(content, &processed);
|
||||||
|
if (processed == content.size()) return d;
|
||||||
|
} catch (const std::invalid_argument& e) {
|
||||||
|
// Not a number, ignore
|
||||||
|
} catch (const std::out_of_range& e) {
|
||||||
|
// Number out of range, ignore
|
||||||
|
}
|
||||||
// TODO: for array, dict, object, function, should further add logic to parse them recursively.
|
// TODO: for array, dict, object, function, should further add logic to parse them recursively.
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
@ -34,7 +45,7 @@ static json parseValue(const std::string& content) {
|
||||||
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
|
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
|
||||||
auto type = ts_node_type(node);
|
auto type = ts_node_type(node);
|
||||||
|
|
||||||
printf("type: %s\n", type);
|
// printf("type: %s\n", type);
|
||||||
// Only interested in call_expression nodes at the outermost level
|
// Only interested in call_expression nodes at the outermost level
|
||||||
if (strcmp(type, "call") == 0) {
|
if (strcmp(type, "call") == 0) {
|
||||||
|
|
||||||
|
@ -92,12 +103,11 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
||||||
std::vector<json> calls;
|
std::vector<json> calls;
|
||||||
std::string delimiter = "<<functions>>";
|
std::string delimiter = "<<functions>>";
|
||||||
std::string source_code;
|
std::string source_code;
|
||||||
printf("source_string: %s\n", source_string.c_str());
|
printf("Parsing source_string: %s\n", source_string.c_str());
|
||||||
size_t startPos = source_string.find(delimiter);
|
size_t startPos = source_string.find(delimiter);
|
||||||
if (startPos != std::string::npos) {
|
if (startPos != std::string::npos) {
|
||||||
source_code = source_string.substr(startPos + delimiter.length());
|
source_code = source_string.substr(startPos + delimiter.length());
|
||||||
} else {
|
} else {
|
||||||
printf("no functions\n");
|
|
||||||
return calls;
|
return calls;
|
||||||
}
|
}
|
||||||
TSParser *parser = ts_parser_new();
|
TSParser *parser = ts_parser_new();
|
||||||
|
|
|
@ -509,9 +509,6 @@ static json oaicompat_completion_params_parse(
|
||||||
std::vector<json> temp_vec;
|
std::vector<json> temp_vec;
|
||||||
std::unordered_map<std::string, std::string> func_observation_map;
|
std::unordered_map<std::string, std::string> func_observation_map;
|
||||||
for (size_t i = 0; i < body["messages"].size(); ++i) {
|
for (size_t i = 0; i < body["messages"].size(); ++i) {
|
||||||
printf("body[\"messages\"][%d][\"role\"] = %s\n", i, body["messages"][i]["role"].get<std::string>().c_str());
|
|
||||||
printf("Message: %s\n", body["messages"][i].dump().c_str());
|
|
||||||
printf("%d\n", body["messages"][i].contains("tool_calls"));
|
|
||||||
|
|
||||||
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
|
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
|
||||||
// insert the observation from the tool call before the next message
|
// insert the observation from the tool call before the next message
|
||||||
|
@ -548,12 +545,9 @@ static json oaicompat_completion_params_parse(
|
||||||
}
|
}
|
||||||
// else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
|
// else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
|
||||||
else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
|
else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
|
||||||
printf("Tool call detected\n");
|
|
||||||
// convert OpenAI function call format to Rubra format
|
// convert OpenAI function call format to Rubra format
|
||||||
std::string tool_call_str = "";
|
std::string tool_call_str = "";
|
||||||
printf("Tool calls: %s\n", body["messages"][i]["tool_calls"].dump().c_str());
|
|
||||||
for (const auto & tool_call : body["messages"][i]["tool_calls"]) {
|
for (const auto & tool_call : body["messages"][i]["tool_calls"]) {
|
||||||
printf("Tool call id: %s\n", tool_call["id"].get<std::string>().c_str());
|
|
||||||
std::string func_str = "";
|
std::string func_str = "";
|
||||||
func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message
|
func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with empty value and later should be updated with the actual value from "tool_call" role message
|
||||||
json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch the exceptions
|
json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch the exceptions
|
||||||
|
@ -570,7 +564,6 @@ static json oaicompat_completion_params_parse(
|
||||||
tool_call_str += func_str;
|
tool_call_str += func_str;
|
||||||
}
|
}
|
||||||
tool_call_str = std::string("<<functions>>") + "[" + tool_call_str + "]";
|
tool_call_str = std::string("<<functions>>") + "[" + tool_call_str + "]";
|
||||||
printf("Tool call string: %s\n", tool_call_str.c_str());
|
|
||||||
|
|
||||||
json function_call;
|
json function_call;
|
||||||
function_call["role"] = "function";
|
function_call["role"] = "function";
|
||||||
|
@ -578,8 +571,6 @@ static json oaicompat_completion_params_parse(
|
||||||
temp_vec.push_back(function_call);
|
temp_vec.push_back(function_call);
|
||||||
}
|
}
|
||||||
else if (body["messages"][i]["role"] == "tool") {
|
else if (body["messages"][i]["role"] == "tool") {
|
||||||
printf("Observation detected\n");
|
|
||||||
printf(body["messages"][i].dump().c_str());
|
|
||||||
std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>();
|
std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>();
|
||||||
if (func_observation_map.find(tool_call_id) != func_observation_map.end()) {
|
if (func_observation_map.find(tool_call_id) != func_observation_map.end()) {
|
||||||
func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>();
|
func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue