Hacky func streaming (#1)
* hacky function call streaming * remove * minor fix to take care of case that the input function has no description or arguments is null * test parser * fix makefile to make sure the order of file linking works for ubuntu gcc/g++ 11.4 * add function name mapping to take care of input function name with hyphen- * add a comment TODO for streaming chunks.
This commit is contained in:
parent
1eafdc95c8
commit
60a01b3ddc
5 changed files with 309 additions and 321 deletions
2
Makefile
2
Makefile
|
@ -753,7 +753,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
|
||||||
|
|
||||||
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
$(CXX) $(CXXFLAGS) $(call GET_OBJ_FILE, $<) $(filter-out %.h %.hpp $<,$^) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||||
|
|
||||||
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
|
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
|
|
|
@ -42,7 +42,7 @@ static json parseValue(const std::string& content) {
|
||||||
|
|
||||||
|
|
||||||
// Recursive function to parse and create JSON for the outer function calls
|
// Recursive function to parse and create JSON for the outer function calls
|
||||||
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
|
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, json tool_name_map, uint32_t indent = 0) {
|
||||||
auto type = ts_node_type(node);
|
auto type = ts_node_type(node);
|
||||||
|
|
||||||
// printf("type: %s\n", type);
|
// printf("type: %s\n", type);
|
||||||
|
@ -60,8 +60,14 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
||||||
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
|
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
|
||||||
|
|
||||||
// Extract the function name
|
// Extract the function name
|
||||||
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
|
std::string func_name = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
|
||||||
|
if (tool_name_map.find(func_name) != tool_name_map.end()){
|
||||||
|
call["name"] = tool_name_map[func_name];
|
||||||
|
} else {
|
||||||
|
call["name"] = func_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("function name: %s\n", call["name"].dump().c_str());
|
||||||
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
|
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
|
||||||
for (unsigned int i = 0; i < numArgs; ++i) {
|
for (unsigned int i = 0; i < numArgs; ++i) {
|
||||||
TSNode argNode = ts_node_named_child(argumentsNode, i);
|
TSNode argNode = ts_node_named_child(argumentsNode, i);
|
||||||
|
@ -94,11 +100,11 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
||||||
unsigned int numChildren = ts_node_child_count(node);
|
unsigned int numChildren = ts_node_child_count(node);
|
||||||
for (unsigned int i = 0; i < numChildren; ++i) {
|
for (unsigned int i = 0; i < numChildren; ++i) {
|
||||||
TSNode child = ts_node_child(node, i);
|
TSNode child = ts_node_child(node, i);
|
||||||
parseFunctionCalls(child, calls, source_code, indent+1);
|
parseFunctionCalls(child, calls, source_code, tool_name_map, indent+1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
static std::vector<json> parsePythonFunctionCalls(std::string source_string, json tool_name_map) {
|
||||||
// Parse Python function calls from the source code and return a JSON array
|
// Parse Python function calls from the source code and return a JSON array
|
||||||
std::vector<json> calls;
|
std::vector<json> calls;
|
||||||
std::string delimiter = "<<functions>>";
|
std::string delimiter = "<<functions>>";
|
||||||
|
@ -124,7 +130,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
||||||
return calls;
|
return calls;
|
||||||
}
|
}
|
||||||
|
|
||||||
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
|
parseFunctionCalls(root_node, calls, source_code_cstr,tool_name_map, 0);
|
||||||
|
|
||||||
ts_tree_delete(tree);
|
ts_tree_delete(tree);
|
||||||
ts_parser_delete(parser);
|
ts_parser_delete(parser);
|
||||||
|
|
|
@ -3230,7 +3230,6 @@ int main(int argc, char ** argv) {
|
||||||
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
|
||||||
|
|
||||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||||
|
|
||||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
@ -3249,12 +3248,26 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
|
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
|
||||||
|
std::string all_content = "";
|
||||||
while (true) {
|
while (true) {
|
||||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||||
if (!result.error) {
|
|
||||||
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
|
|
||||||
|
|
||||||
|
std::string this_content = json_value(result.data, "content", std::string(""));
|
||||||
|
// TODO: this block is just a hacky solution to enable function calling in streaming -- by concat the streaming chunks.
|
||||||
|
// Ideally: If the first a few tokens is <<functions>>, it should keep waiting for all chunks, otherwise do normal stream logic.
|
||||||
|
if (this_content != "") {
|
||||||
|
all_content += this_content;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
if (all_content != "") {
|
||||||
|
result.data["content"] = all_content;
|
||||||
|
all_content = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!result.error) {
|
||||||
|
std::vector<json> result_array = format_partial_response_oaicompat(data, result.data, completion_id);
|
||||||
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
|
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
|
||||||
if (!it->empty()) {
|
if (!it->empty()) {
|
||||||
const std::string str =
|
const std::string str =
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||||
|
|
||||||
|
@ -337,8 +339,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
|
static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
|
||||||
std::string final_str = "You have access to the following tools:\n";
|
std::string final_str = "You have access to the following tools:\n";
|
||||||
|
printf("rubra_format_function_call_str parsing...\n");
|
||||||
json type_mapping = {
|
json type_mapping = {
|
||||||
{"string", "str"},
|
{"string", "str"},
|
||||||
{"integer", "int"},
|
{"integer", "int"},
|
||||||
|
@ -352,10 +355,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
|
||||||
std::vector<std::string> function_definitions;
|
std::vector<std::string> function_definitions;
|
||||||
for (const auto & function : functions) {
|
for (const auto & function : functions) {
|
||||||
const auto &spec = function.contains("function") ? function["function"] : function;
|
const auto &spec = function.contains("function") ? function["function"] : function;
|
||||||
const std::string func_name = spec.value("name", "");
|
std::string func_name = spec.value("name", "");
|
||||||
const std::string description = spec.value("description", "");
|
if (func_name.find('-') != std::string::npos) {
|
||||||
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
|
const std::string origin_func_name = func_name;
|
||||||
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
|
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in python func name
|
||||||
|
tool_name_map[func_name] = origin_func_name;
|
||||||
|
}
|
||||||
|
const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
|
||||||
|
const auto& parameters = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("properties", json({})) : json({});
|
||||||
|
const auto& required_params = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
|
||||||
|
|
||||||
std::vector<std::string> func_args;
|
std::vector<std::string> func_args;
|
||||||
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
|
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
|
||||||
|
@ -481,15 +489,16 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["__oaicompat"] = true;
|
llama_params["__oaicompat"] = true;
|
||||||
|
|
||||||
std::string function_str = "";
|
std::string function_str = "";
|
||||||
|
json tool_name_map;
|
||||||
|
|
||||||
if (body.contains("tools") && !body["tools"].empty()) {
|
if (body.contains("tools") && !body["tools"].empty()) {
|
||||||
// function_str = default_tool_formatter(body["tool"]);
|
// function_str = default_tool_formatter(body["tool"]);
|
||||||
function_str = rubra_format_function_call_str(body["tools"]);
|
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
|
||||||
}
|
}
|
||||||
// If 'tool' is not set or empty, check 'functions'
|
// If 'tool' is not set or empty, check 'functions'
|
||||||
else if (body.contains("functions") && !body["functions"].empty()) {
|
else if (body.contains("functions") && !body["functions"].empty()) {
|
||||||
// function_str = default_tool_formatter(body["functions"]);
|
// function_str = default_tool_formatter(body["functions"]);
|
||||||
function_str = rubra_format_function_call_str(body["functions"]);
|
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
|
||||||
}
|
}
|
||||||
printf("\n=============Formatting Input from OPENAI format...============\n");
|
printf("\n=============Formatting Input from OPENAI format...============\n");
|
||||||
if (function_str != "") {
|
if (function_str != "") {
|
||||||
|
@ -607,6 +616,7 @@ static json oaicompat_completion_params_parse(
|
||||||
else {
|
else {
|
||||||
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
|
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
|
||||||
}
|
}
|
||||||
|
llama_params["tool_name_map"] = tool_name_map;
|
||||||
|
|
||||||
// Map OpenAI parameters to llama.cpp parameters
|
// Map OpenAI parameters to llama.cpp parameters
|
||||||
//
|
//
|
||||||
|
@ -661,9 +671,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
||||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||||
std::string content = json_value(result, "content", std::string(""));
|
std::string content = json_value(result, "content", std::string(""));
|
||||||
|
|
||||||
std::vector<json> parsed_content = parsePythonFunctionCalls(content);
|
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
|
||||||
|
|
||||||
|
|
||||||
std::string finish_reason = "length";
|
std::string finish_reason = "length";
|
||||||
if (stopped_word || stopped_eos) {
|
if (stopped_word || stopped_eos) {
|
||||||
finish_reason = "stop";
|
finish_reason = "stop";
|
||||||
|
@ -732,7 +741,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
||||||
}
|
}
|
||||||
|
|
||||||
// return value is vector as there is one case where we might need to generate two responses
|
// return value is vector as there is one case where we might need to generate two responses
|
||||||
static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
|
static std::vector<json> format_partial_response_oaicompat(json request ,json result, const std::string & completion_id) {
|
||||||
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||||
return std::vector<json>({result});
|
return std::vector<json>({result});
|
||||||
}
|
}
|
||||||
|
@ -745,6 +754,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
bool stopped_limit = json_value(result, "stopped_limit", false);
|
||||||
std::string content = json_value(result, "content", std::string(""));
|
std::string content = json_value(result, "content", std::string(""));
|
||||||
|
|
||||||
|
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
|
||||||
|
std::time_t t = std::time(0);
|
||||||
|
if (!parsed_content.empty()) {
|
||||||
|
std::vector<json> res;
|
||||||
|
json choices1 = json::array({json{{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"role", "assistant"}}}}});
|
||||||
|
|
||||||
|
json ret = json{
|
||||||
|
{"choices", choices1},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}
|
||||||
|
};
|
||||||
|
res.push_back(ret);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < parsed_content.size(); ++i) {
|
||||||
|
const auto &pc = parsed_content[i];
|
||||||
|
// Use 'pc' and 'i' as needed
|
||||||
|
json tool_call1;
|
||||||
|
tool_call1["id"] = pc["id"];
|
||||||
|
tool_call1["type"] = "function";
|
||||||
|
tool_call1["index"] = i;
|
||||||
|
tool_call1["function"] = json{
|
||||||
|
{"name" , pc["name"]},
|
||||||
|
{"arguments" , ""},
|
||||||
|
};
|
||||||
|
json ret1 = json{
|
||||||
|
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"tool_calls", std::vector<json>{tool_call1}}}}}})
|
||||||
|
},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}
|
||||||
|
};
|
||||||
|
res.push_back(ret1);
|
||||||
|
json tool_call2;
|
||||||
|
tool_call2["index"] = i;
|
||||||
|
tool_call2["function"] = json{
|
||||||
|
{"name" , ""},
|
||||||
|
{"arguments" , pc["kwargs"].dump()},
|
||||||
|
};
|
||||||
|
json ret2 = json{
|
||||||
|
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"tool_calls", std::vector<json>{tool_call2}}}}}})
|
||||||
|
},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}
|
||||||
|
};
|
||||||
|
res.push_back(ret2);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
std::string finish_reason;
|
std::string finish_reason;
|
||||||
if (stopped_word || stopped_eos) {
|
if (stopped_word || stopped_eos) {
|
||||||
finish_reason = "stop";
|
finish_reason = "stop";
|
||||||
|
@ -753,7 +822,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
finish_reason = "length";
|
finish_reason = "length";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::time_t t = std::time(0);
|
|
||||||
|
|
||||||
json choices;
|
json choices;
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Function Definitions"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -10,6 +17,27 @@
|
||||||
"import uuid\n",
|
"import uuid\n",
|
||||||
"from functools import partial\n",
|
"from functools import partial\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def add(args: str):\n",
|
||||||
|
" args = json.loads(args)\n",
|
||||||
|
" return str(float(args[\"a\"]) + float(args[\"b\"]))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def sub(args: str):\n",
|
||||||
|
" args = json.loads(args)\n",
|
||||||
|
" return str(float(args[\"a\"]) - float(args[\"b\"]))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def mult(args: str):\n",
|
||||||
|
" args = json.loads(args)\n",
|
||||||
|
" return str(float(args[\"a\"]) * float(args[\"b\"]))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def div(args: str):\n",
|
||||||
|
" args = json.loads(args)\n",
|
||||||
|
" return str(float(args[\"a\"]) / float(args[\"b\"]))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"def get_oai_response(model, functions, msgs, api_key, base_url):\n",
|
"def get_oai_response(model, functions, msgs, api_key, base_url):\n",
|
||||||
" import openai\n",
|
" import openai\n",
|
||||||
" openai.api_key = api_key ## Add your API key here\n",
|
" openai.api_key = api_key ## Add your API key here\n",
|
||||||
|
@ -52,11 +80,39 @@
|
||||||
" l = len((json.loads(assistant_message.tool_calls[i].function.arguments))[\"location\"])\n",
|
" l = len((json.loads(assistant_message.tool_calls[i].function.arguments))[\"location\"])\n",
|
||||||
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {(i+1) * 50 + l } degree\"})\n",
|
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {(i+1) * 50 + l } degree\"})\n",
|
||||||
" elif tool_call.function.name == \"calculate_distance\":\n",
|
" elif tool_call.function.name == \"calculate_distance\":\n",
|
||||||
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Distance is {(i+1) * 1700} miles.\"})\n",
|
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Distance is {(i+1) * 50} miles.\"})\n",
|
||||||
" elif tool_call.function.name == \"generate_password\":\n",
|
" elif tool_call.function.name == \"generate_password\":\n",
|
||||||
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Password generated: {uuid.uuid4().hex[:8]}\"})\n",
|
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Password generated: {uuid.uuid4().hex[:8]}\"})\n",
|
||||||
" else:\n",
|
" elif tool_call.function.name == \"orderUmbrella\":\n",
|
||||||
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed. the price is {(i+1) * 10} dollars.\"})\n",
|
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed. the price is {(i+1) * 10} dollars.\"})\n",
|
||||||
|
" elif tool_call.function.name == \"addition\":\n",
|
||||||
|
" msgs.append({\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"name\": \"addition\",\n",
|
||||||
|
" \"content\": add(tool_call.function.arguments),\n",
|
||||||
|
" \"tool_call_id\": tool_call.id\n",
|
||||||
|
" })\n",
|
||||||
|
" elif tool_call.function.name == \"subtraction\":\n",
|
||||||
|
" msgs.append({\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"name\": \"subtraction\",\n",
|
||||||
|
" \"content\": sub(tool_call.function.arguments),\n",
|
||||||
|
" \"tool_call_id\": tool_call.id\n",
|
||||||
|
" })\n",
|
||||||
|
" elif tool_call.function.name == \"multiplication\":\n",
|
||||||
|
" msgs.append({\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"name\": \"multiplication\",\n",
|
||||||
|
" \"content\": mult(tool_call.function.arguments),\n",
|
||||||
|
" \"tool_call_id\": tool_call.id\n",
|
||||||
|
" })\n",
|
||||||
|
" elif tool_call.function.name == \"division\":\n",
|
||||||
|
" msgs.append({\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"name\": \"division\",\n",
|
||||||
|
" \"content\": div(tool_call.function.arguments),\n",
|
||||||
|
" \"tool_call_id\": tool_call.id\n",
|
||||||
|
" })\n",
|
||||||
" \n",
|
" \n",
|
||||||
" return msgs\n",
|
" return msgs\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -64,7 +120,90 @@
|
||||||
" system_prompt = \"You are a helpful assistant.\"\n",
|
" system_prompt = \"You are a helpful assistant.\"\n",
|
||||||
" functions = [\n",
|
" functions = [\n",
|
||||||
" # {\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}},\n",
|
" # {\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}},\n",
|
||||||
" \n",
|
" {\n",
|
||||||
|
" 'type': 'function',\n",
|
||||||
|
" 'function': {\n",
|
||||||
|
" 'name': 'addition',\n",
|
||||||
|
" 'description': \"Adds two numbers together\",\n",
|
||||||
|
" 'parameters': {\n",
|
||||||
|
" 'type': 'object',\n",
|
||||||
|
" 'properties': {\n",
|
||||||
|
" 'a': {\n",
|
||||||
|
" 'description': 'First number to add',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" },\n",
|
||||||
|
" 'b': {\n",
|
||||||
|
" 'description': 'Second number to add',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" 'required': []\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" 'type': 'function',\n",
|
||||||
|
" 'function': {\n",
|
||||||
|
" 'name': 'subtraction',\n",
|
||||||
|
" 'description': \"Subtracts two numbers\",\n",
|
||||||
|
" 'parameters': {\n",
|
||||||
|
" 'type': 'object',\n",
|
||||||
|
" 'properties': {\n",
|
||||||
|
" 'a': {\n",
|
||||||
|
" 'description': 'First number to be subtracted from',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" },\n",
|
||||||
|
" 'b': {\n",
|
||||||
|
" 'description': 'Number to subtract',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" 'required': []\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" 'type': 'function',\n",
|
||||||
|
" 'function': {\n",
|
||||||
|
" 'name': 'multiplication',\n",
|
||||||
|
" 'description': \"Multiply two numbers together\",\n",
|
||||||
|
" 'parameters': {\n",
|
||||||
|
" 'type': 'object',\n",
|
||||||
|
" 'properties': {\n",
|
||||||
|
" 'a': {\n",
|
||||||
|
" 'description': 'First number to multiply',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" },\n",
|
||||||
|
" 'b': {\n",
|
||||||
|
" 'description': 'Second number to multiply',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" 'required': []\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" 'type': 'function',\n",
|
||||||
|
" 'function': {\n",
|
||||||
|
" 'name': 'division',\n",
|
||||||
|
" 'description': \"Divide two numbers\",\n",
|
||||||
|
" 'parameters': {\n",
|
||||||
|
" 'type': 'object',\n",
|
||||||
|
" 'properties': {\n",
|
||||||
|
" 'a': {\n",
|
||||||
|
" 'description': 'First number to use as the dividend',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" },\n",
|
||||||
|
" 'b': {\n",
|
||||||
|
" 'description': 'Second number to use as the divisor',\n",
|
||||||
|
" 'type': 'string'\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" 'required': []\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"type\": \"function\",\n",
|
" \"type\": \"function\",\n",
|
||||||
" \"function\": {\n",
|
" \"function\": {\n",
|
||||||
|
@ -115,18 +254,21 @@
|
||||||
" print(\"\\n[AI calling functions]:\")\n",
|
" print(\"\\n[AI calling functions]:\")\n",
|
||||||
" for tool_call in res_next.message.tool_calls:\n",
|
" for tool_call in res_next.message.tool_calls:\n",
|
||||||
" print(f\"Tool Call: {tool_call.function}\")\n",
|
" print(f\"Tool Call: {tool_call.function}\")\n",
|
||||||
|
" l = 0\n",
|
||||||
" while res_next.message.tool_calls and len(res_next.message.tool_calls) > 0:\n",
|
" while res_next.message.tool_calls and len(res_next.message.tool_calls) > 0:\n",
|
||||||
" msgs = insert_tool_response(res_next, msgs)\n",
|
" msgs = insert_tool_response(res_next, msgs)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" res_next = chat_method(model=\"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
|
" res_next = chat_method(model=\"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
|
||||||
" # for m in msgs:\n",
|
" # for m in msgs:\n",
|
||||||
" # print(m)\n",
|
" # print(m)\n",
|
||||||
|
" print(f\"Loop {l}\")\n",
|
||||||
" if res_next.message.content and len(res_next.message.content) > 0:\n",
|
" if res_next.message.content and len(res_next.message.content) > 0:\n",
|
||||||
" print(\"\\n[AI response]:\\n\", res_next.message.content)\n",
|
" print(\"\\n[AI response]:\\n\", res_next.message.content)\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" print(\"\\n[AI calling functions]:\")\n",
|
" print(\"\\n[AI calling functions]:\")\n",
|
||||||
" for tool_call in res_next.message.tool_calls:\n",
|
" for tool_call in res_next.message.tool_calls:\n",
|
||||||
" print(f\"Tool Call: {tool_call.function}\")\n",
|
" print(f\"Tool Call: {tool_call.function}\")\n",
|
||||||
|
" l += 1\n",
|
||||||
" "
|
" "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -134,12 +276,23 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Function.cpp"
|
"## Multi + Parallel Function Call"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import openai\n",
|
||||||
|
"local_api_key = \"sk-\"\n",
|
||||||
|
"local_base_url = \"http://localhost:8019/v1/\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -149,359 +302,106 @@
|
||||||
"Pointing to URL: http://localhost:8019/v1/\n",
|
"Pointing to URL: http://localhost:8019/v1/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[AI calling functions]:\n",
|
"[AI calling functions]:\n",
|
||||||
"Tool Call: Function(arguments='{\"location\":\"Boston, MA\"}', name='getCurrentWeather')\n",
|
"Tool Call: Function(arguments='{\"destination\":\"Cupertino\",\"mode\":\"driving\",\"origin\":\"San Francisco\"}', name='calculate_distance')\n",
|
||||||
"\n",
|
"Tool Call: Function(arguments='{\"destination\":\"San Francisco\",\"mode\":\"driving\",\"origin\":\"Cupertino\"}', name='calculate_distance')\n",
|
||||||
"Pointing to URL: http://localhost:8019/v1/\n",
|
"Tool Call: Function(arguments='{\"destination\":\"Cupertino\",\"mode\":\"air\",\"origin\":\"San Francisco\"}', name='calculate_distance')\n",
|
||||||
"\n",
|
"Tool Call: Function(arguments='{\"destination\":\"San Francisco\",\"mode\":\"air\",\"origin\":\"Cupertino\"}', name='calculate_distance')\n",
|
||||||
"[AI calling functions]:\n",
|
|
||||||
"Tool Call: Function(arguments='{\"location\":\"Cupertino, CA\"}', name='getCurrentWeather')\n",
|
|
||||||
"\n",
|
|
||||||
"Pointing to URL: http://localhost:8019/v1/\n",
|
|
||||||
"\n",
|
|
||||||
"[AI calling functions]:\n",
|
|
||||||
"Tool Call: Function(arguments='{\"location\":\"Los Angeles, CA\"}', name='getCurrentWeather')\n",
|
|
||||||
"\n",
|
|
||||||
"Pointing to URL: http://localhost:8019/v1/\n",
|
|
||||||
"\n",
|
|
||||||
"[AI calling functions]:\n",
|
|
||||||
"Tool Call: Function(arguments='{\"location\":\"New York City, NY\"}', name='getCurrentWeather')\n",
|
|
||||||
"\n",
|
|
||||||
"Pointing to URL: http://localhost:8019/v1/\n",
|
"Pointing to URL: http://localhost:8019/v1/\n",
|
||||||
|
"Loop 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[AI response]:\n",
|
"[AI response]:\n",
|
||||||
" The distance from Boston, MA to Cupertino, CA is approximately 2,800 miles. The distance from Los Angeles, CA to New York City, NY is approximately 2,800 miles as well.\n"
|
" The distance between San Francisco and Cupertino by driving is 50 miles and 100 miles from Cupertino to San Francisco. When traveling by air, the distance is 150 miles from San Francisco to Cupertino and 200 miles from Cupertino to San Francisco.\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import openai\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
|
||||||
"local_api_key = \"sk-\"\n",
|
|
||||||
"local_base_url = \"http://localhost:8019/v1/\"\n",
|
|
||||||
"get_mistral_rubra_response = partial(get_oai_response, api_key=local_api_key, base_url=local_base_url)\n",
|
"get_mistral_rubra_response = partial(get_oai_response, api_key=local_api_key, base_url=local_base_url)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"user_query = \"calculate the distance from boston to cupertino? and distance from LA to NYC\"\n",
|
"user_query = \"What is the distance between San Francisco and Cupertino by driving and by air from both directions?\"\n",
|
||||||
"msgs = run_completion(get_mistral_rubra_response, user_query)\n",
|
"# user_query = \"What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\"\n",
|
||||||
"# user_query = \"what's the weather in Boston and Cupertino and Chicago?\"\n",
|
"# user_query = \"what's the distance between SF and NYC? Use the result value to multiply by 8, and then divide by 2, and then minus 30\"\n",
|
||||||
"# # user_query = \"order 2 umbrellas\"\n",
|
"msgs = run_completion(get_mistral_rubra_response, user_query)\n"
|
||||||
"# msgs = run_completion(get_mistral_rubra_response, user_query)\n",
|
|
||||||
"# user_query2 = \"now order 3 umbrellas for me\"\n",
|
|
||||||
"# msgs = run_completion(get_mistral_rubra_response, user_query2, msgs)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## OpenAI"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Pointing to URL: https://api.openai.com/v1/\n",
|
"Pointing to URL: http://localhost:8019/v1/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[AI calling functions]:\n",
|
"[AI calling functions]:\n",
|
||||||
"Tool Call: Function(arguments='{\"location\": \"Boston, MA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
|
"Tool Call: Function(arguments='{\"number_to_buy\":3}', name='orderUmbrella')\n",
|
||||||
"Tool Call: Function(arguments='{\"location\": \"Cupertino, CA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
|
"Tool Call: Function(arguments='{\"length\":8}', name='generate_password')\n",
|
||||||
"Tool Call: Function(arguments='{\"origin\": \"Boston, MA\", \"destination\": \"Cupertino, CA\", \"mode\": \"driving\"}', name='calculate_distance')\n",
|
"Pointing to URL: http://localhost:8019/v1/\n",
|
||||||
"\n",
|
"Loop 0\n",
|
||||||
"\n",
|
|
||||||
"Pointing to URL: https://api.openai.com/v1/\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"[AI response]:\n",
|
"[AI response]:\n",
|
||||||
" The current weather in Boston, MA is 60°F, and in Cupertino, CA, it is 113°F. The distance from Boston, MA to Cupertino, CA is approximately 5100 miles when traveling by driving.\n"
|
" Your order for 3 umbrellas has been placed, and the total price is 10 dollars. The generated password is 96ddefe8.\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import openai\n",
|
"user_query2 = \"now order 3 umbrellas for me and generate a password of length 8\"\n",
|
||||||
"\n",
|
"msgs = run_completion(get_mistral_rubra_response, user_query2, msgs)"
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"openai_api_key = \"sk-\"\n",
|
|
||||||
"openai_base_url = \"https://api.openai.com/v1/\"\n",
|
|
||||||
"get_openai_response = partial(get_oai_response, api_key=openai_api_key, base_url=openai_base_url)\n",
|
|
||||||
"\n",
|
|
||||||
"# oai_user_query = \"What is the distance between San Francisco and Cupertino by car and by air\"\n",
|
|
||||||
"oai_user_query = \"weather in boston as well as cupertino? and calculate the distance from boston to cupertino\"\n",
|
|
||||||
"# user_query = \"order 2 umbrellas\"\n",
|
|
||||||
"run_completion(get_openai_response, oai_user_query)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"#### IGNORE the following for now."
|
"## Simple Math Chaining"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "TypeError",
|
|
||||||
"evalue": "get_oai_response() got multiple values for argument 'functions'",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn[4], line 48\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;66;03m# \u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\u001b[39;00m\n\u001b[1;32m 47\u001b[0m msgs \u001b[38;5;241m=\u001b[39m [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msystem\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m:system_prompt} ,{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: user_query},]\n\u001b[0;32m---> 48\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mget_mistral_rubra_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43muser_query\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmistral_rubra\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunctions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunctions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28mprint\u001b[39m(res\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mcontent)\n",
|
|
||||||
"\u001b[0;31mTypeError\u001b[0m: get_oai_response() got multiple values for argument 'functions'"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"system_prompt = \"You are a helpful assistant.\"\n",
|
|
||||||
"functions = [\n",
|
|
||||||
" {\"function\":\n",
|
|
||||||
" {\n",
|
|
||||||
" \"name\": \"get_stock_fundermentals\",\n",
|
|
||||||
" \"description\": \"Get the stock fundermentals data\",\n",
|
|
||||||
" \"parameters\": {\n",
|
|
||||||
" \"type\": \"object\",\n",
|
|
||||||
" \"properties\": {\n",
|
|
||||||
" \"symbol\": {\n",
|
|
||||||
" \"type\": \"string\",\n",
|
|
||||||
" \"description\": \"The stock symbol, e.g. AAPL, GOOG\"\n",
|
|
||||||
" }\n",
|
|
||||||
" },\n",
|
|
||||||
" \"required\": [\n",
|
|
||||||
" \"symbol\"\n",
|
|
||||||
" ]\n",
|
|
||||||
" }\n",
|
|
||||||
" }},\n",
|
|
||||||
" {\"function\":{\n",
|
|
||||||
" \"name\": \"check_word_anagram\",\n",
|
|
||||||
" \"description\": \"Check if two words are anagrams of each other\",\n",
|
|
||||||
" \"parameters\": {\n",
|
|
||||||
" \"type\": \"object\",\n",
|
|
||||||
" \"properties\": {\n",
|
|
||||||
" \"word1\": {\n",
|
|
||||||
" \"type\": \"string\",\n",
|
|
||||||
" \"description\": \"The first word\"\n",
|
|
||||||
" },\n",
|
|
||||||
" \"word2\": {\n",
|
|
||||||
" \"type\": \"string\",\n",
|
|
||||||
" \"description\": \"The second word\"\n",
|
|
||||||
" }\n",
|
|
||||||
" },\n",
|
|
||||||
" \"required\": [\n",
|
|
||||||
" \"word1\",\n",
|
|
||||||
" \"word2\"\n",
|
|
||||||
" ]\n",
|
|
||||||
" }\n",
|
|
||||||
" }}\n",
|
|
||||||
"]\n",
|
|
||||||
"\n",
|
|
||||||
"user_query = \"What's the stock fundementals of Tesla and google\"\n",
|
|
||||||
"\n",
|
|
||||||
"# \n",
|
|
||||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\n",
|
|
||||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
|
|
||||||
"res = get_mistral_rubra_response(user_query, \"mistral_rubra\", functions=functions, msgs=msgs)\n",
|
|
||||||
"print(res.message.content)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"content: <<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func(x= 1, b='2', c=123)]\n",
|
"Pointing to URL: http://localhost:8019/v1/\n",
|
||||||
"[\"get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit')\", \" func(x= 1, b='2', c=123))\"]\n"
|
"\n",
|
||||||
|
"[AI calling functions]:\n",
|
||||||
|
"Tool Call: Function(arguments='{\"a\":\"4\",\"b\":\"6\"}', name='addition')\n",
|
||||||
|
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"2\"}', name='addition')\n",
|
||||||
|
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"5\"}', name='multiplication')\n",
|
||||||
|
"Tool Call: Function(arguments='{\"a\":\"result\",\"b\":\"2\"}', name='division')\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ename": "SyntaxError",
|
"ename": "ValueError",
|
||||||
"evalue": "unterminated string literal (detected at line 1) (<unknown>, line 1)",
|
"evalue": "could not convert string to float: 'result'",
|
||||||
"output_type": "error",
|
"output_type": "error",
|
||||||
"traceback": [
|
"traceback": [
|
||||||
"Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.12/envs/py310/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3548\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||||
"\u001b[0m Cell \u001b[1;32mIn[47], line 40\u001b[0m\n result_dict = parse_function_call(function_call.strip())\u001b[0m\n",
|
"Cell \u001b[0;32mIn[9], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m user_query3 \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUser tool to help me : What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m msgs \u001b[38;5;241m=\u001b[39m \u001b[43mrun_completion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mget_mistral_rubra_response\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muser_query3\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
"\u001b[0m Cell \u001b[1;32mIn[47], line 22\u001b[0m in \u001b[1;35mparse_function_call\u001b[0m\n parsed_value = ast.literal_eval(value)\u001b[0m\n",
|
"Cell \u001b[0;32mIn[3], line 244\u001b[0m, in \u001b[0;36mrun_completion\u001b[0;34m(chat_method, user_query, msgs)\u001b[0m\n\u001b[1;32m 242\u001b[0m l \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 243\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m res_next\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mtool_calls \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(res_next\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mtool_calls) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 244\u001b[0m msgs \u001b[38;5;241m=\u001b[39m \u001b[43minsert_tool_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres_next\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 246\u001b[0m res_next \u001b[38;5;241m=\u001b[39m chat_method(model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpt-4-0125-preview\u001b[39m\u001b[38;5;124m\"\u001b[39m, functions\u001b[38;5;241m=\u001b[39mfunctions, msgs\u001b[38;5;241m=\u001b[39mmsgs)\n\u001b[1;32m 247\u001b[0m \u001b[38;5;66;03m# for m in msgs:\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# print(m)\u001b[39;00m\n",
|
||||||
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.12/lib/python3.10/ast.py:64\u001b[0m in \u001b[1;35mliteral_eval\u001b[0m\n node_or_string = parse(node_or_string.lstrip(\" \\t\"), mode='eval')\u001b[0m\n",
|
"Cell \u001b[0;32mIn[3], line 77\u001b[0m, in \u001b[0;36minsert_tool_response\u001b[0;34m(res, msgs)\u001b[0m\n\u001b[1;32m 72\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mstr\u001b[39m(assistant_message\u001b[38;5;241m.\u001b[39mtool_calls[i]\u001b[38;5;241m.\u001b[39mid), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: assistant_message\u001b[38;5;241m.\u001b[39mtool_calls[i]\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOrder placed. the price is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m(i\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;250m \u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m10\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dollars.\u001b[39m\u001b[38;5;124m\"\u001b[39m})\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maddition\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 74\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maddition\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_call\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marguments\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: tool_call\u001b[38;5;241m.\u001b[39mid\n\u001b[1;32m 79\u001b[0m })\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msubtraction\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 81\u001b[0m msgs\u001b[38;5;241m.\u001b[39mappend({\n\u001b[1;32m 82\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 83\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msubtraction\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 84\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: sub(tool_call\u001b[38;5;241m.\u001b[39mfunction\u001b[38;5;241m.\u001b[39marguments),\n\u001b[1;32m 85\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_call_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: tool_call\u001b[38;5;241m.\u001b[39mid\n\u001b[1;32m 86\u001b[0m })\n",
|
||||||
"\u001b[0;36m File \u001b[0;32m~/.pyenv/versions/3.10.12/lib/python3.10/ast.py:50\u001b[0;36m in \u001b[0;35mparse\u001b[0;36m\n\u001b[0;31m return compile(source, filename, mode, flags,\u001b[0;36m\n",
|
"Cell \u001b[0;32mIn[3], line 8\u001b[0m, in \u001b[0;36madd\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21madd\u001b[39m(args: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m args \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(args)\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mfloat\u001b[39m(args[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m]))\n",
|
||||||
"\u001b[0;36m File \u001b[0;32m<unknown>:1\u001b[0;36m\u001b[0m\n\u001b[0;31m 'Boston\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m unterminated string literal (detected at line 1)\n"
|
"\u001b[0;31mValueError\u001b[0m: could not convert string to float: 'result'"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import json\n",
|
"user_query3 = \"User tool to help me : What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\"\n",
|
||||||
"import re\n",
|
"msgs = run_completion(get_mistral_rubra_response, user_query3, msgs)"
|
||||||
"import ast\n",
|
|
||||||
"\n",
|
|
||||||
"content = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func(x= 1, b='2', c=123)]\"\n",
|
|
||||||
"regex = re.compile(r\"<<functions>>\\[(.*?)\\]\", re.DOTALL)\n",
|
|
||||||
"matches = re.findall(regex, content)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"content:\", content)\n",
|
|
||||||
"\n",
|
|
||||||
"def parse_function_call(call):\n",
|
|
||||||
" func_name, args_str = call.split('(', 1)\n",
|
|
||||||
" args_str = args_str.rstrip(')')\n",
|
|
||||||
" args_list = args_str.split(',')\n",
|
|
||||||
" args_dict = {}\n",
|
|
||||||
" for arg in args_list:\n",
|
|
||||||
" key, value = arg.split('=')\n",
|
|
||||||
" key = key.strip()\n",
|
|
||||||
" value = value.strip()\n",
|
|
||||||
" try:\n",
|
|
||||||
" # Use ast.literal_eval to safely parse the string to its Python type\n",
|
|
||||||
" parsed_value = ast.literal_eval(value)\n",
|
|
||||||
" except ValueError as e:\n",
|
|
||||||
" # If parsing fails, keep the original string. \n",
|
|
||||||
" # This might happen if the value is a string that's not quoted as a Python literal.\n",
|
|
||||||
" print(f\"Error parsing value {value}: {e}\")\n",
|
|
||||||
" parsed_value = value\n",
|
|
||||||
" args_dict[key] = parsed_value\n",
|
|
||||||
" return {\"name\": func_name.strip(), \"arguments\": args_dict}\n",
|
|
||||||
"\n",
|
|
||||||
"result_dicts = []\n",
|
|
||||||
"for match in matches:\n",
|
|
||||||
" # Splitting each function call from the match. We add ')' back because it was used as a delimiter\n",
|
|
||||||
" function_calls = [f\"{func})\" for func in match.split('),') if func]\n",
|
|
||||||
" print(function_calls)\n",
|
|
||||||
" for function_call in function_calls:\n",
|
|
||||||
" # Removing the trailing ')' that was added for the last function call\n",
|
|
||||||
" if function_call.endswith(')'):\n",
|
|
||||||
" function_call = function_call[:-1]\n",
|
|
||||||
" result_dict = parse_function_call(function_call.strip())\n",
|
|
||||||
" result_dicts.append(result_dict)\n",
|
|
||||||
" print(result_dicts)\n",
|
|
||||||
"\n",
|
|
||||||
"res = json.dumps(result_dicts, ensure_ascii=False)\n",
|
|
||||||
"res"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
"source": []
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[{'name': 'get_current_weather', 'args': [], 'kwargs': {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}}, {'name': 'func', 'args': ['cde'], 'kwargs': {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]}}]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import ast\n",
|
|
||||||
"\n",
|
|
||||||
"raw_input_str = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x= 1, b='2', c=[1, 2, {'a': 1, 'b': 2}])]\"\n",
|
|
||||||
"# raw_input_str = \"<<functions>>[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func( x=1, b='2', c=123)]\"\n",
|
|
||||||
"input_str = raw_input_str.split('<<functions>>')[1]\n",
|
|
||||||
"# Parse the string into an AST\n",
|
|
||||||
"parsed_ast = ast.parse(input_str, mode='eval')\n",
|
|
||||||
"\n",
|
|
||||||
"def find_calls(node):\n",
|
|
||||||
" calls = []\n",
|
|
||||||
" if isinstance(node, ast.Call): # If it's a function call\n",
|
|
||||||
" calls.append(node)\n",
|
|
||||||
" return calls\n",
|
|
||||||
" for child in ast.iter_child_nodes(node):\n",
|
|
||||||
" calls.extend(find_calls(child))\n",
|
|
||||||
" return calls\n",
|
|
||||||
"\n",
|
|
||||||
"# Extract all function call nodes\n",
|
|
||||||
"calls = find_calls(parsed_ast.body)\n",
|
|
||||||
"\n",
|
|
||||||
"functions = []\n",
|
|
||||||
"for call in calls:\n",
|
|
||||||
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
|
|
||||||
" function_name = call.func.id\n",
|
|
||||||
" args = [ast.literal_eval(arg) for arg in call.args] # Convert all positional arguments\n",
|
|
||||||
" kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
|
|
||||||
" functions.append({\"name\": function_name, \"args\": args, \"kwargs\":kwargs})\n",
|
|
||||||
"\n",
|
|
||||||
"print(functions)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': ['func_nested(1, 2)', {'a': \"func_deep('value')\"}]})]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import ast\n",
|
|
||||||
"\n",
|
|
||||||
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x= 1, b='2', c=[func_nested(1, 2), {'a': func_deep('value')}])]\"\n",
|
|
||||||
"\n",
|
|
||||||
"def ast_node_to_object(node):\n",
|
|
||||||
" if isinstance(node, ast.Constant):\n",
|
|
||||||
" return node.value\n",
|
|
||||||
" elif isinstance(node, ast.List):\n",
|
|
||||||
" return [ast_node_to_object(n) for n in node.elts]\n",
|
|
||||||
" elif isinstance(node, ast.Dict):\n",
|
|
||||||
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
|
|
||||||
" elif isinstance(node, ast.Tuple):\n",
|
|
||||||
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
|
|
||||||
" elif isinstance(node, ast.Call):\n",
|
|
||||||
" return ast.unparse(node)\n",
|
|
||||||
" # Handle function calls: convert to a representation with the function name and arguments\n",
|
|
||||||
" # func_name = ast_node_to_object(node.func) # Get the function name\n",
|
|
||||||
" # args = [ast_node_to_object(arg) for arg in node.args] # Convert all positional arguments\n",
|
|
||||||
" # kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in node.keywords} # Convert all keyword arguments\n",
|
|
||||||
" # return {\"function\": func_name, \"args\": args, \"kwargs\": kwargs}\n",
|
|
||||||
" elif isinstance(node, ast.Name):\n",
|
|
||||||
" return node.id # Return the identifier name\n",
|
|
||||||
" # Add more cases here as needed\n",
|
|
||||||
" return None\n",
|
|
||||||
"\n",
|
|
||||||
"# Parse the string into an AST\n",
|
|
||||||
"parsed_ast = ast.parse(input_str, mode='eval')\n",
|
|
||||||
"\n",
|
|
||||||
"# Function to find only the top-level Call nodes\n",
|
|
||||||
"def find_top_level_calls(node):\n",
|
|
||||||
" calls = []\n",
|
|
||||||
" if isinstance(node, ast.Call): # If it's a function call\n",
|
|
||||||
" calls.append(node)\n",
|
|
||||||
" # Do not descend into child nodes to ensure we're only capturing top-level calls\n",
|
|
||||||
" return calls\n",
|
|
||||||
" for child in ast.iter_child_nodes(node):\n",
|
|
||||||
" # Recursively find calls without going into nested calls\n",
|
|
||||||
" calls.extend(find_top_level_calls(child))\n",
|
|
||||||
" return calls\n",
|
|
||||||
"\n",
|
|
||||||
"# Extract all top-level function call nodes\n",
|
|
||||||
"top_level_calls = find_top_level_calls(parsed_ast.body)\n",
|
|
||||||
"\n",
|
|
||||||
"# Process each call node to get the details you want\n",
|
|
||||||
"functions = []\n",
|
|
||||||
"for call in top_level_calls:\n",
|
|
||||||
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
|
|
||||||
" function_name = call.func.id\n",
|
|
||||||
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
|
|
||||||
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
|
|
||||||
" functions.append((function_name, args, kwargs))\n",
|
|
||||||
"\n",
|
|
||||||
"print(functions)\n"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue