Hacky func streaming (#1)

* hacky function call streaming

* remove

* minor fix to handle the case where the input function has no description or its arguments are null

* test parser

* fix makefile to make sure the order of file linking works for ubuntu gcc/g++ 11.4

* add function name mapping to handle input function names containing hyphens

* add a comment TODO for streaming chunks.
This commit is contained in:
Yingbei Tong 2024-04-05 20:59:58 +00:00 committed by GitHub
parent 1eafdc95c8
commit 60a01b3ddc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 309 additions and 321 deletions

View file

@ -753,7 +753,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(CXX) $(CXXFLAGS) $(call GET_OBJ_FILE, $<) $(filter-out %.h %.hpp $<,$^) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

View file

@ -42,7 +42,7 @@ static json parseValue(const std::string& content) {
// Recursive function to parse and create JSON for the outer function calls
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, json tool_name_map, uint32_t indent = 0) {
auto type = ts_node_type(node);
// printf("type: %s\n", type);
@ -60,8 +60,14 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
// Extract the function name
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
std::string func_name = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
if (tool_name_map.find(func_name) != tool_name_map.end()){
call["name"] = tool_name_map[func_name];
} else {
call["name"] = func_name;
}
printf("function name: %s\n", call["name"].dump().c_str());
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
for (unsigned int i = 0; i < numArgs; ++i) {
TSNode argNode = ts_node_named_child(argumentsNode, i);
@ -94,11 +100,11 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
unsigned int numChildren = ts_node_child_count(node);
for (unsigned int i = 0; i < numChildren; ++i) {
TSNode child = ts_node_child(node, i);
parseFunctionCalls(child, calls, source_code, indent+1);
parseFunctionCalls(child, calls, source_code, tool_name_map, indent+1);
}
}
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
static std::vector<json> parsePythonFunctionCalls(std::string source_string, json tool_name_map) {
// Parse Python function calls from the source code and return a JSON array
std::vector<json> calls;
std::string delimiter = "<<functions>>";
@ -124,7 +130,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
return calls;
}
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
parseFunctionCalls(root_node, calls, source_code_cstr,tool_name_map, 0);
ts_tree_delete(tree);
ts_parser_delete(parser);

View file

@ -3230,7 +3230,6 @@ int main(int argc, char ** argv) {
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
@ -3249,12 +3248,26 @@ int main(int argc, char ** argv) {
}
ctx_server.queue_results.remove_waiting_task_id(id_task);
} else {
const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
std::string all_content = "";
while (true) {
server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
std::string this_content = json_value(result.data, "content", std::string(""));
// TODO: this block is just a hacky solution to enable function calling in streaming -- by concat the streaming chunks.
// Ideally: If the first a few tokens is <<functions>>, it should keep waiting for all chunks, otherwise do normal stream logic.
if (this_content != "") {
all_content += this_content;
continue;
} else {
if (all_content != "") {
result.data["content"] = all_content;
all_content = "";
}
}
if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(data, result.data, completion_id);
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
if (!it->empty()) {
const std::string str =

View file

@ -10,6 +10,8 @@
#include <vector>
#include <sstream>
#include <random>
#include <unordered_map>
#include <algorithm>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@ -337,8 +339,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
//
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
std::string final_str = "You have access to the following tools:\n";
printf("rubra_format_function_call_str parsing...\n");
json type_mapping = {
{"string", "str"},
{"integer", "int"},
@ -352,10 +355,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
std::vector<std::string> function_definitions;
for (const auto & function : functions) {
const auto &spec = function.contains("function") ? function["function"] : function;
const std::string func_name = spec.value("name", "");
const std::string description = spec.value("description", "");
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
std::string func_name = spec.value("name", "");
if (func_name.find('-') != std::string::npos) {
const std::string origin_func_name = func_name;
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in python func name
tool_name_map[func_name] = origin_func_name;
}
const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
const auto& parameters = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
std::vector<std::string> func_args;
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
@ -481,15 +489,16 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;
std::string function_str = "";
json tool_name_map;
if (body.contains("tools") && !body["tools"].empty()) {
// function_str = default_tool_formatter(body["tool"]);
function_str = rubra_format_function_call_str(body["tools"]);
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {
// function_str = default_tool_formatter(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
}
printf("\n=============Formatting Input from OPENAI format...============\n");
if (function_str != "") {
@ -607,6 +616,7 @@ static json oaicompat_completion_params_parse(
else {
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
}
llama_params["tool_name_map"] = tool_name_map;
// Map OpenAI parameters to llama.cpp parameters
//
@ -661,8 +671,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::vector<json> parsed_content = parsePythonFunctionCalls(content);
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
@ -732,7 +741,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
static std::vector<json> format_partial_response_oaicompat(json request ,json result, const std::string & completion_id) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
@ -745,6 +754,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
std::time_t t = std::time(0);
if (!parsed_content.empty()) {
std::vector<json> res;
json choices1 = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
json ret = json{
{"choices", choices1},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret);
for (size_t i = 0; i < parsed_content.size(); ++i) {
const auto &pc = parsed_content[i];
// Use 'pc' and 'i' as needed
json tool_call1;
tool_call1["id"] = pc["id"];
tool_call1["type"] = "function";
tool_call1["index"] = i;
tool_call1["function"] = json{
{"name" , pc["name"]},
{"arguments" , ""},
};
json ret1 = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"tool_calls", std::vector<json>{tool_call1}}}}}})
},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret1);
json tool_call2;
tool_call2["index"] = i;
tool_call2["function"] = json{
{"name" , ""},
{"arguments" , pc["kwargs"].dump()},
};
json ret2 = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"tool_calls", std::vector<json>{tool_call2}}}}}})
},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret2);
}
return res;
}
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
@ -753,7 +822,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;

View file

@ -1,8 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function Definitions"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -10,6 +17,27 @@
"import uuid\n",
"from functools import partial\n",
"\n",
"\n",
"def add(args: str):\n",
" args = json.loads(args)\n",
" return str(float(args[\"a\"]) + float(args[\"b\"]))\n",
"\n",
"\n",
"def sub(args: str):\n",
" args = json.loads(args)\n",
" return str(float(args[\"a\"]) - float(args[\"b\"]))\n",
"\n",
"\n",
"def mult(args: str):\n",
" args = json.loads(args)\n",
" return str(float(args[\"a\"]) * float(args[\"b\"]))\n",
"\n",
"\n",
"def div(args: str):\n",
" args = json.loads(args)\n",
" return str(float(args[\"a\"]) / float(args[\"b\"]))\n",
"\n",
"\n",
"def get_oai_response(model, functions, msgs, api_key, base_url):\n",
" import openai\n",
" openai.api_key = api_key ## Add your API key here\n",
@ -52,11 +80,39 @@
" l = len((json.loads(assistant_message.tool_calls[i].function.arguments))[\"location\"])\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {(i+1) * 50 + l } degree\"})\n",
" elif tool_call.function.name == \"calculate_distance\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Distance is {(i+1) * 1700} miles.\"})\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Distance is {(i+1) * 50} miles.\"})\n",
" elif tool_call.function.name == \"generate_password\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Password generated: {uuid.uuid4().hex[:8]}\"})\n",
" else:\n",
" elif tool_call.function.name == \"orderUmbrella\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed. the price is {(i+1) * 10} dollars.\"})\n",
" elif tool_call.function.name == \"addition\":\n",
" msgs.append({\n",
" \"role\": \"tool\",\n",
" \"name\": \"addition\",\n",
" \"content\": add(tool_call.function.arguments),\n",
" \"tool_call_id\": tool_call.id\n",
" })\n",
" elif tool_call.function.name == \"subtraction\":\n",
" msgs.append({\n",
" \"role\": \"tool\",\n",
" \"name\": \"subtraction\",\n",
" \"content\": sub(tool_call.function.arguments),\n",
" \"tool_call_id\": tool_call.id\n",
" })\n",
" elif tool_call.function.name == \"multiplication\":\n",
" msgs.append({\n",
" \"role\": \"tool\",\n",
" \"name\": \"multiplication\",\n",
" \"content\": mult(tool_call.function.arguments),\n",
" \"tool_call_id\": tool_call.id\n",
" })\n",
" elif tool_call.function.name == \"division\":\n",
" msgs.append({\n",
" \"role\": \"tool\",\n",
" \"name\": \"division\",\n",
" \"content\": div(tool_call.function.arguments),\n",
" \"tool_call_id\": tool_call.id\n",
" })\n",
" \n",
" return msgs\n",
"\n",
@ -64,7 +120,90 @@
" system_prompt = \"You are a helpful assistant.\"\n",
" functions = [\n",
" # {\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}},\n",
" \n",
" {\n",
" 'type': 'function',\n",
" 'function': {\n",
" 'name': 'addition',\n",
" 'description': \"Adds two numbers together\",\n",
" 'parameters': {\n",
" 'type': 'object',\n",
" 'properties': {\n",
" 'a': {\n",
" 'description': 'First number to add',\n",
" 'type': 'string'\n",
" },\n",
" 'b': {\n",
" 'description': 'Second number to add',\n",
" 'type': 'string'\n",
" }\n",
" },\n",
" 'required': []\n",
" }\n",
" }\n",
" },\n",
" {\n",
" 'type': 'function',\n",
" 'function': {\n",
" 'name': 'subtraction',\n",
" 'description': \"Subtracts two numbers\",\n",
" 'parameters': {\n",
" 'type': 'object',\n",
" 'properties': {\n",
" 'a': {\n",
" 'description': 'First number to be subtracted from',\n",
" 'type': 'string'\n",
" },\n",
" 'b': {\n",
" 'description': 'Number to subtract',\n",
" 'type': 'string'\n",
" }\n",
" },\n",
" 'required': []\n",
" }\n",
" }\n",
" },\n",
" {\n",
" 'type': 'function',\n",
" 'function': {\n",
" 'name': 'multiplication',\n",
" 'description': \"Multiply two numbers together\",\n",
" 'parameters': {\n",
" 'type': 'object',\n",
" 'properties': {\n",
" 'a': {\n",
" 'description': 'First number to multiply',\n",
" 'type': 'string'\n",
" },\n",
" 'b': {\n",
" 'description': 'Second number to multiply',\n",
" 'type': 'string'\n",
" }\n",
" },\n",
" 'required': []\n",
" }\n",
" }\n",
" },\n",
" {\n",
" 'type': 'function',\n",
" 'function': {\n",
" 'name': 'division',\n",
" 'description': \"Divide two numbers\",\n",
" 'parameters': {\n",
" 'type': 'object',\n",
" 'properties': {\n",
" 'a': {\n",
" 'description': 'First number to use as the dividend',\n",
" 'type': 'string'\n",
" },\n",
" 'b': {\n",
" 'description': 'Second number to use as the divisor',\n",
" 'type': 'string'\n",
" }\n",
" },\n",
" 'required': []\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
@ -115,18 +254,21 @@
" print(\"\\n[AI calling functions]:\")\n",
" for tool_call in res_next.message.tool_calls:\n",
" print(f\"Tool Call: {tool_call.function}\")\n",
" l = 0\n",
" while res_next.message.tool_calls and len(res_next.message.tool_calls) > 0:\n",
" msgs = insert_tool_response(res_next, msgs)\n",
"\n",
" res_next = chat_method(model=\"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
" # for m in msgs:\n",
" # print(m)\n",
" print(f\"Loop {l}\")\n",
" if res_next.message.content and len(res_next.message.content) > 0:\n",
" print(\"\\n[AI response]:\\n\", res_next.message.content)\n",
" else:\n",
" print(\"\\n[AI calling functions]:\")\n",
" for tool_call in res_next.message.tool_calls:\n",
" print(f\"Tool Call: {tool_call.function}\")\n",
" l += 1\n",
" "
]
},
@ -134,12 +276,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function.cpp"
"## Multi + Parallel Function Call"
]
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"local_api_key = \"sk-\"\n",
"local_base_url = \"http://localhost:8019/v1/\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -149,359 +302,106 @@
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Boston, MA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Cupertino, CA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Los Angeles, CA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"New York City, NY\"}', name='getCurrentWeather')\n",
"\n",
"Tool Call: Function(arguments='{\"destination\":\"Cupertino\",\"mode\":\"driving\",\"origin\":\"San Francisco\"}', name='calculate_distance')\n",
"Tool Call: Function(arguments='{\"destination\":\"San Francisco\",\"mode\":\"driving\",\"origin\":\"Cupertino\"}', name='calculate_distance')\n",
"Tool Call: Function(arguments='{\"destination\":\"Cupertino\",\"mode\":\"air\",\"origin\":\"San Francisco\"}', name='calculate_distance')\n",
"Tool Call: Function(arguments='{\"destination\":\"San Francisco\",\"mode\":\"air\",\"origin\":\"Cupertino\"}', name='calculate_distance')\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI response]:\n",
" The distance from Boston, MA to Cupertino, CA is approximately 2,800 miles. The distance from Los Angeles, CA to New York City, NY is approximately 2,800 miles as well.\n"
" The distance between San Francisco and Cupertino by driving is 50 miles and 100 miles from Cupertino to San Francisco. When traveling by air, the distance is 150 miles from San Francisco to Cupertino and 200 miles from Cupertino to San Francisco.\n"
]
}
],
"source": [
"import openai\n",
"\n",
"\n",
"local_api_key = \"sk-\"\n",
"local_base_url = \"http://localhost:8019/v1/\"\n",
"get_mistral_rubra_response = partial(get_oai_response, api_key=local_api_key, base_url=local_base_url)\n",
"\n",
"user_query = \"calculate the distance from boston to cupertino? and distance from LA to NYC\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query)\n",
"# user_query = \"what's the weather in Boston and Cupertino and Chicago?\"\n",
"# # user_query = \"order 2 umbrellas\"\n",
"# msgs = run_completion(get_mistral_rubra_response, user_query)\n",
"# user_query2 = \"now order 3 umbrellas for me\"\n",
"# msgs = run_completion(get_mistral_rubra_response, user_query2, msgs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI"
"user_query = \"What is the distance between San Francisco and Cupertino by driving and by air from both directions?\"\n",
"# user_query = \"What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\"\n",
"# user_query = \"what's the distance between SF and NYC? Use the result value to multiply by 8, and then divide by 2, and then minus 30\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query)\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pointing to URL: https://api.openai.com/v1/\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\": \"Boston, MA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
"Tool Call: Function(arguments='{\"location\": \"Cupertino, CA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
"Tool Call: Function(arguments='{\"origin\": \"Boston, MA\", \"destination\": \"Cupertino, CA\", \"mode\": \"driving\"}', name='calculate_distance')\n",
"\n",
"\n",
"Pointing to URL: https://api.openai.com/v1/\n",
"Tool Call: Function(arguments='{\"number_to_buy\":3}', name='orderUmbrella')\n",
"Tool Call: Function(arguments='{\"length\":8}', name='generate_password')\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI response]:\n",
" The current weather in Boston, MA is 60°F, and in Cupertino, CA, it is 113°F. The distance from Boston, MA to Cupertino, CA is approximately 5100 miles when traveling by driving.\n"
" Your order for 3 umbrellas has been placed, and the total price is 10 dollars. The generated password is 96ddefe8.\n"
]
}
],
"source": [
"import openai\n",
"\n",
"\n",
"\n",
"\n",
"openai_api_key = \"sk-\"\n",
"openai_base_url = \"https://api.openai.com/v1/\"\n",
"get_openai_response = partial(get_oai_response, api_key=openai_api_key, base_url=openai_base_url)\n",
"\n",
"# oai_user_query = \"What is the distance between San Francisco and Cupertino by car and by air\"\n",
"oai_user_query = \"weather in boston as well as cupertino? and calculate the distance from boston to cupertino\"\n",
"# user_query = \"order 2 umbrellas\"\n",
"run_completion(get_openai_response, oai_user_query)"
"user_query2 = \"now order 3 umbrellas for me and generate a password of length 8\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query2, msgs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### IGNORE the following for now."
"## Simple Math Chaining"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "get_oai_response() got multiple values for argument 'functions'",
"name": "stdout",
"output_type": "stream",
"text": [
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"a\":\"4\",\"b\":\"6\"}', name='addition')\n",
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"2\"}', name='addition')\n",
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"5\"}', name='multiplication')\n",
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"2\"}', name='division')\n"
]
},
{
"ename": "ValueError",
"evalue": "could not convert string to float: 'result'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 48\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;66;03m# \u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\u001b[39;00m\n\u001b[1;32m 47\u001b[0m msgs \u001b[38;5;241m=\u001b[39m [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msystem\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m:system_prompt} ,{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: user_query},]\n\u001b[0;32m---> 48\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mget_mistral_rubra_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43muser_query\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmistral_rubra\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunctions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunctions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m 
\u001b[38;5;28mprint\u001b[39m(res\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mcontent)\n",
"\u001b[0;31mTypeError\u001b[0m: get_oai_response() got multiple values for argument 'functions'"
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m user_query3 \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUser tool to help me : What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m msgs \u001b[38;5;241m=\u001b[39m \u001b[43mrun_completion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mget_mistral_rubra_response\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muser_query3\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[3], line 244\u001b[0m, in \u001b[0;36mrun_completion\u001b[0;34m(chat_method, user_query, msgs)\u001b[0m\n\u001b[1;32m 242\u001b[0m l \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 243\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m res_next\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mtool_calls \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(res_next\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mtool_calls) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 244\u001b[0m msgs \u001b[38;5;241m=\u001b[39m \u001b[43minsert_tool_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres_next\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 246\u001b[0m res_next \u001b[38;5;241m=\u001b[39m chat_method(model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpt-4-0125-preview\u001b[39m\u001b[38;5;124m\"\u001b[39m, functions\u001b[38;5;241m=\u001b[39mfunctions, msgs\u001b[38;5;241m=\u001b[39mmsgs)\n\u001b[1;32m 247\u001b[0m \u001b[38;5;66;03m# for m in msgs:\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# print(m)\u001b[39;00m\n",
"Cell \u001b[0;32mIn[3], line 77\u001b[0m, in \u001b[0;36minsert_tool_response\u001b[0;34m(res, msgs)\u001b[0m\n\u001b[1;32m 72\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mstr\u001b[39m(assistant_message\u001b[38;5;241m.\u001b[39mtool_calls[i]\u001b[38;5;241m.\u001b[39mid), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: assistant_message\u001b[38;5;241m.\u001b[39mtool_calls[i]\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOrder placed. the price is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m(i\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;250m \u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m10\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dollars.\u001b[39m\u001b[38;5;124m\"\u001b[39m})\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maddition\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 74\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maddition\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_call\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marguments\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: tool_call\u001b[38;5;241m.\u001b[39mid\n\u001b[1;32m 79\u001b[0m })\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msubtraction\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 81\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\n\u001b[1;32m 82\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 83\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msubtraction\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 84\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: sub(tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39marguments),\n\u001b[1;32m 85\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: tool_call\u001b[38;5;241m.\u001b[39mid\n\u001b[1;32m 86\u001b[0m })\n",
"Cell \u001b[0;32mIn[3], line 8\u001b[0m, in \u001b[0;36madd\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21madd\u001b[39m(args: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m args \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(args)\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mfloat\u001b[39m(args[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m]))\n",
"\u001b[0;31mValueError\u001b[0m: could not convert string to float: 'result'"
]
}
],
"source": [
"system_prompt = \"You are a helpful assistant.\"\n",
"functions = [\n",
" {\"function\":\n",
" {\n",
" \"name\": \"get_stock_fundermentals\",\n",
    "        \"description\": \"Get the stock fundamentals data\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"symbol\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The stock symbol, e.g. AAPL, GOOG\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"symbol\"\n",
" ]\n",
" }\n",
" }},\n",
" {\"function\":{\n",
" \"name\": \"check_word_anagram\",\n",
" \"description\": \"Check if two words are anagrams of each other\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"word1\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The first word\"\n",
" },\n",
" \"word2\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The second word\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"word1\",\n",
" \"word2\"\n",
" ]\n",
" }\n",
" }}\n",
"]\n",
"\n",
"user_query = \"What's the stock fundementals of Tesla and google\"\n",
"\n",
"# \n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
"res = get_mistral_rubra_response(user_query, \"mistral_rubra\", functions=functions, msgs=msgs)\n",
    "print(res.message.content)\n",
"user_query3 = \"User tool to help me : What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query3, msgs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"content: <<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func(x= 1, b='2', c=123)]\n",
"[\"get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit')\", \" func(x= 1, b='2', c=123))\"]\n"
]
},
{
"ename": "SyntaxError",
"evalue": "unterminated string literal (detected at line 1) (<unknown>, line 1)",
"output_type": "error",
"traceback": [
"Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.12/envs/py310/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3548\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
"\u001b[0m Cell \u001b[1;32mIn[47], line 40\u001b[0m\n result_dict = parse_function_call(function_call.strip())\u001b[0m\n",
"\u001b[0m Cell \u001b[1;32mIn[47], line 22\u001b[0m in \u001b[1;35mparse_function_call\u001b[0m\n parsed_value = ast.literal_eval(value)\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.12/lib/python3.10/ast.py:64\u001b[0m in \u001b[1;35mliteral_eval\u001b[0m\n node_or_string = parse(node_or_string.lstrip(\" \\t\"), mode='eval')\u001b[0m\n",
"\u001b[0;36m File \u001b[0;32m~/.pyenv/versions/3.10.12/lib/python3.10/ast.py:50\u001b[0;36m in \u001b[0;35mparse\u001b[0;36m\n\u001b[0;31m return compile(source, filename, mode, flags,\u001b[0;36m\n",
"\u001b[0;36m File \u001b[0;32m<unknown>:1\u001b[0;36m\u001b[0m\n\u001b[0;31m 'Boston\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m unterminated string literal (detected at line 1)\n"
]
}
],
"source": [
"import json\n",
"import re\n",
"import ast\n",
"\n",
"content = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func(x= 1, b='2', c=123)]\"\n",
"regex = re.compile(r\"<<functions>>\\[(.*?)\\]\", re.DOTALL)\n",
"matches = re.findall(regex, content)\n",
"\n",
"print(\"content:\", content)\n",
"\n",
"def parse_function_call(call):\n",
" func_name, args_str = call.split('(', 1)\n",
" args_str = args_str.rstrip(')')\n",
" args_list = args_str.split(',')\n",
" args_dict = {}\n",
" for arg in args_list:\n",
" key, value = arg.split('=')\n",
" key = key.strip()\n",
" value = value.strip()\n",
" try:\n",
" # Use ast.literal_eval to safely parse the string to its Python type\n",
" parsed_value = ast.literal_eval(value)\n",
" except ValueError as e:\n",
" # If parsing fails, keep the original string. \n",
" # This might happen if the value is a string that's not quoted as a Python literal.\n",
" print(f\"Error parsing value {value}: {e}\")\n",
" parsed_value = value\n",
" args_dict[key] = parsed_value\n",
" return {\"name\": func_name.strip(), \"arguments\": args_dict}\n",
"\n",
"result_dicts = []\n",
"for match in matches:\n",
" # Splitting each function call from the match. We add ')' back because it was used as a delimiter\n",
" function_calls = [f\"{func})\" for func in match.split('),') if func]\n",
" print(function_calls)\n",
" for function_call in function_calls:\n",
" # Removing the trailing ')' that was added for the last function call\n",
" if function_call.endswith(')'):\n",
" function_call = function_call[:-1]\n",
" result_dict = parse_function_call(function_call.strip())\n",
" result_dicts.append(result_dict)\n",
" print(result_dicts)\n",
"\n",
"res = json.dumps(result_dicts, ensure_ascii=False)\n",
"res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'name': 'get_current_weather', 'args': [], 'kwargs': {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}}, {'name': 'func', 'args': ['cde'], 'kwargs': {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]}}]\n"
]
}
],
"source": [
"import ast\n",
"\n",
"raw_input_str = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x= 1, b='2', c=[1, 2, {'a': 1, 'b': 2}])]\"\n",
"# raw_input_str = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func( x=1, b='2', c=123)]\"\n",
"input_str = raw_input_str.split('<<functions>>')[1]\n",
"# Parse the string into an AST\n",
"parsed_ast = ast.parse(input_str, mode='eval')\n",
"\n",
"def find_calls(node):\n",
" calls = []\n",
" if isinstance(node, ast.Call): # If it's a function call\n",
" calls.append(node)\n",
" return calls\n",
" for child in ast.iter_child_nodes(node):\n",
" calls.extend(find_calls(child))\n",
" return calls\n",
"\n",
"# Extract all function call nodes\n",
"calls = find_calls(parsed_ast.body)\n",
"\n",
"functions = []\n",
"for call in calls:\n",
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
" function_name = call.func.id\n",
" args = [ast.literal_eval(arg) for arg in call.args] # Convert all positional arguments\n",
" kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
" functions.append({\"name\": function_name, \"args\": args, \"kwargs\":kwargs})\n",
"\n",
"print(functions)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': ['func_nested(1, 2)', {'a': \"func_deep('value')\"}]})]\n"
]
}
],
"source": [
"import ast\n",
"\n",
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x= 1, b='2', c=[func_nested(1, 2), {'a': func_deep('value')}])]\"\n",
"\n",
"def ast_node_to_object(node):\n",
" if isinstance(node, ast.Constant):\n",
" return node.value\n",
" elif isinstance(node, ast.List):\n",
" return [ast_node_to_object(n) for n in node.elts]\n",
" elif isinstance(node, ast.Dict):\n",
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
" elif isinstance(node, ast.Tuple):\n",
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
" elif isinstance(node, ast.Call):\n",
" return ast.unparse(node)\n",
" # Handle function calls: convert to a representation with the function name and arguments\n",
" # func_name = ast_node_to_object(node.func) # Get the function name\n",
" # args = [ast_node_to_object(arg) for arg in node.args] # Convert all positional arguments\n",
" # kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in node.keywords} # Convert all keyword arguments\n",
" # return {\"function\": func_name, \"args\": args, \"kwargs\": kwargs}\n",
" elif isinstance(node, ast.Name):\n",
" return node.id # Return the identifier name\n",
" # Add more cases here as needed\n",
" return None\n",
"\n",
"# Parse the string into an AST\n",
"parsed_ast = ast.parse(input_str, mode='eval')\n",
"\n",
"# Function to find only the top-level Call nodes\n",
"def find_top_level_calls(node):\n",
" calls = []\n",
" if isinstance(node, ast.Call): # If it's a function call\n",
" calls.append(node)\n",
" # Do not descend into child nodes to ensure we're only capturing top-level calls\n",
" return calls\n",
" for child in ast.iter_child_nodes(node):\n",
" # Recursively find calls without going into nested calls\n",
" calls.extend(find_top_level_calls(child))\n",
" return calls\n",
"\n",
"# Extract all top-level function call nodes\n",
"top_level_calls = find_top_level_calls(parsed_ast.body)\n",
"\n",
"# Process each call node to get the details you want\n",
"functions = []\n",
"for call in top_level_calls:\n",
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
" function_name = call.func.id\n",
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
" functions.append((function_name, args, kwargs))\n",
"\n",
"print(functions)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
}
],
"metadata": {