a first working version integrated tree_sitter with python parser code

This commit is contained in:
Yingbei 2024-03-18 18:21:33 -07:00
parent 48c02498f2
commit 9bd7dbb17b
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
7 changed files with 1540 additions and 19 deletions

View file

@ -669,6 +669,12 @@ clean:
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS) rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
find examples pocs -type f -name "*.o" -delete find examples pocs -type f -name "*.o" -delete
scanner.o: examples/server/tree_sitter/tree-sitter-python/src/scanner.c
$(CC) $(CFLAGS) -c $< -o $@
parser.o: examples/server/tree_sitter/tree-sitter-python/src/parser.c
$(CC) $(CFLAGS) -c $< -o $@
# #
# Examples # Examples
# #
@ -735,8 +741,8 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS) gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)

View file

@ -0,0 +1,126 @@
#include <iostream>
#include <tree_sitter/api.h>
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>
#include <string>
#include "json.hpp" // Include the JSON library
extern "C" TSLanguage *tree_sitter_python();
using json = nlohmann::json; // Use an alias for easier access
// Convert the source text of a single Python literal into a JSON value.
//
// Recognised forms, checked in this order:
//   * a non-empty run of decimal digits        -> JSON integer
//   * True / true / False / false              -> JSON boolean
//   * text wrapped in matching '...' or "..."  -> JSON string without the quotes
// Anything else (floats, signed numbers, lists, dicts, nested calls, ...) is
// returned unchanged as a JSON string holding the raw source text.
static json parseValue(const std::string& content) {
    // Check for numerical value.
    // The predicate takes unsigned char: passing a negative char (any non-ASCII
    // byte) straight to isdigit is undefined behaviour.
    const bool all_digits = !content.empty() &&
        std::all_of(content.begin(), content.end(),
                    [](unsigned char c) { return std::isdigit(c) != 0; });
    if (all_digits) {
        try {
            // 64-bit conversion so that long ids / keys survive; std::stoi
            // would throw std::out_of_range as soon as the value exceeds int.
            return std::stoll(content);
        } catch (const std::out_of_range&) {
            // More digits than fit in 64 bits: keep the raw text instead of
            // letting the exception propagate to the caller.
            return content;
        }
    }

    // Check for boolean
    if (content == "True" || content == "true") {
        return true;
    }
    if (content == "False" || content == "false") {
        return false;
    }

    // Quoted string: strip the surrounding pair of identical quotes.
    if (content.size() >= 2) {
        const char first = content.front();
        const char last = content.back();
        if ((first == '"' && last == '"') || (first == '\'' && last == '\'')) {
            return content.substr(1, content.size() - 2);
        }
    }

    // TODO: for array, dict, object, function, should further add logic to parse them recursively.
    return content;
}
// Recursive function to parse and create JSON for the outer function calls
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
auto type = ts_node_type(node);
printf("type: %s\n", type);
// Only interested in call_expression nodes at the outermost level
if (strcmp(type, "call") == 0) {
json call = {
{"name", ""},
{"args", json::array()},
{"kwargs", json::object()}
};
TSNode functionNode = ts_node_child(node, 0); // The function name node
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
// Extract the function name
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
// Loop through the arguments
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
for (unsigned int i = 0; i < numArgs; ++i) {
TSNode argNode = ts_node_named_child(argumentsNode, i);
const char* argType = ts_node_type(argNode);
// Check if the argument is a positional argument or a keyword argument
if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) {
// For simplification, we treat the entire content as the argument
std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode));
call["args"].push_back(parseValue(value));
} else if (strcmp(argType, "keyword_argument") == 0) {
// Extract keyword and value for keyword arguments
TSNode keyNode = ts_node_child(argNode, 0); // The key of the kwarg
TSNode valueNode = ts_node_child(argNode, 2); // The value of the kwarg, 1 is the symbol `=`
// if this is 0 then it's a string/integer/boolean, simply parse it
// unsigned int numValueNodeChild = ts_node_named_child_count(valueNode);
// TODO: if numValueNodeChild != 0 then it's an array/list/object?/function. Need to do something more. However for now we assume this will not happen.
std::string key = std::string(source_code + ts_node_start_byte(keyNode), ts_node_end_byte(keyNode) - ts_node_start_byte(keyNode));
std::string value = std::string(source_code + ts_node_start_byte(valueNode), ts_node_end_byte(valueNode) - ts_node_start_byte(valueNode));
call["kwargs"][key] = parseValue(value);
}
}
calls.push_back(call);
return; // Stop recursion to only process outer function calls
}
// Recurse through all children for other node types
unsigned int numChildren = ts_node_child_count(node);
for (unsigned int i = 0; i < numChildren; ++i) {
TSNode child = ts_node_child(node, i);
parseFunctionCalls(child, calls, source_code, indent+1);
}
}
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
std::vector<json> calls;
std::string delimiter = "<<functions>>";
std::string source_code;
printf("source_string: %s\n", source_string.c_str());
size_t startPos = source_string.find(delimiter);
if (startPos != std::string::npos) {
source_code = source_string.substr(startPos + delimiter.length());
} else {
printf("no functions\n");
return calls;
}
TSParser *parser = ts_parser_new();
ts_parser_set_language(parser, tree_sitter_python());
const char* source_code_cstr = source_code.c_str();
TSTree *tree = ts_parser_parse_string(parser, nullptr, source_code_cstr, source_code.length());
TSNode root_node = ts_tree_root_node(tree);
bool has_errors = ts_node_has_error(root_node);
if (has_errors) {
// probably a regular string
printf("has errors\n");
return calls;
}
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
// Output the parsed calls
ts_tree_delete(tree);
ts_parser_delete(parser);
printf("calls: %s\n", json(calls).dump().c_str());
return calls;
}

Binary file not shown.

@ -0,0 +1 @@
Subproject commit b8a4c64121ba66b460cb878e934e3157ecbfb124

File diff suppressed because it is too large Load diff

View file

@ -4,6 +4,7 @@
#include "common.h" #include "common.h"
#include "json.hpp" #include "json.hpp"
#include "python-parser.hpp"
#include <string> #include <string>
#include <vector> #include <vector>
@ -421,7 +422,7 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
return final_str; return final_str;
} }
std::string default_tool_formatter(const std::vector<json>& tools) { static std::string default_tool_formatter(const std::vector<json>& tools) {
std::string toolText = ""; std::string toolText = "";
std::vector<std::string> toolNames; std::vector<std::string> toolNames;
for (const auto& tool : tools) { for (const auto& tool : tools) {
@ -556,11 +557,6 @@ static json oaicompat_completion_params_parse(
} }
static json parse_response_for_function_call(const std::string content) {
}
static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) { static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
bool stopped_word = result.count("stopped_word") != 0; bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false); bool stopped_eos = json_value(result, "stopped_eos", false);
@ -568,6 +564,9 @@ static json format_final_response_oaicompat(const json & request, json result, c
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string("")); std::string content = json_value(result, "content", std::string(""));
std::vector<json> parsed_content = parsePythonFunctionCalls(content);
std::string finish_reason = "length"; std::string finish_reason = "length";
if (stopped_word || stopped_eos) { if (stopped_word || stopped_eos) {
finish_reason = "stop"; finish_reason = "stop";
@ -579,7 +578,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
{"delta", json::object()}}}) {"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason}, : json::array({json{{"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"message", json{{"content", content}, {"message", json{{"content", parsed_content},
{"role", "assistant"}}}}}); {"role", "assistant"}}}}});
std::time_t t = std::time(0); std::time_t t = std::time(0);

View file

@ -29,14 +29,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 161, "execution_count": 163,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n" " <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
] ]
} }
], ],
@ -93,14 +93,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 160, "execution_count": 177,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n" " <<functions>>[get_stock_fundermentals(symbol=\"TSLA\")]\n",
"<<functions>>[get_stock_fundermentals(symbol=\"GOOG\")]\n"
] ]
} }
], ],
@ -109,8 +110,8 @@
"functions = [\n", "functions = [\n",
" {\"function\":\n", " {\"function\":\n",
" {\n", " {\n",
" \"name\": \"get_stock_price\",\n", " \"name\": \"get_stock_fundermentals\",\n",
" \"description\": \"Get the current stock price\",\n", " \"description\": \"Get the stock fundermentals data\",\n",
" \"parameters\": {\n", " \"parameters\": {\n",
" \"type\": \"object\",\n", " \"type\": \"object\",\n",
" \"properties\": {\n", " \"properties\": {\n",
@ -147,14 +148,140 @@
" }}\n", " }}\n",
"]\n", "]\n",
"\n", "\n",
"user_query = \"What's the stock price of Tesla and Google?\"\n", "user_query = \"What's the stock fundementals of Tesla and google\"\n",
"\n", "\n",
"# \n", "# \n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>170, 138\"}]\n", "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", "res = get_mistral_rubra_response(user_query, \"mistral_rubra\", functions=functions, msgs=msgs)\n",
"print(res.message.content)" "print(res.message.content)"
] ]
},
{
"cell_type": "code",
"execution_count": 183,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<ast.List object at 0x105649390>\n",
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
]
}
],
"source": [
"import ast\n",
"\n",
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[1, 2, {'a': 1, 'b': 2}])]\"\n",
"\n",
"# Parse the string into an AST\n",
"parsed_ast = ast.parse(input_str, mode='eval')\n",
"\n",
"# Function to convert an AST node to a Python object\n",
"def ast_node_to_object(node):\n",
" if isinstance(node, ast.Constant):\n",
" return node.value\n",
" elif isinstance(node, ast.List):\n",
" return [ast_node_to_object(n) for n in node.elts]\n",
" elif isinstance(node, ast.Dict):\n",
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
" elif isinstance(node, ast.Tuple):\n",
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
" # Add more cases here as needed\n",
" return None\n",
"\n",
"def find_calls(node):\n",
" calls = []\n",
" if isinstance(node, ast.Call): # If it's a function call\n",
" calls.append(node)\n",
" for child in ast.iter_child_nodes(node):\n",
" calls.extend(find_calls(child))\n",
" return calls\n",
"\n",
"# Extract all function call nodes\n",
"calls = find_calls(parsed_ast.body)\n",
"\n",
"functions = []\n",
"for call in calls:\n",
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
" function_name = call.func.id\n",
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
" functions.append((function_name, args, kwargs))\n",
"\n",
"print(functions)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': ['func_nested(1, 2)', {'a': \"func_deep('value')\"}]})]\n"
]
}
],
"source": [
"import ast\n",
"\n",
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[func_nested(1, 2), {'a': func_deep('value')}])]\"\n",
"\n",
"def ast_node_to_object(node):\n",
" if isinstance(node, ast.Constant):\n",
" return node.value\n",
" elif isinstance(node, ast.List):\n",
" return [ast_node_to_object(n) for n in node.elts]\n",
" elif isinstance(node, ast.Dict):\n",
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
" elif isinstance(node, ast.Tuple):\n",
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
" elif isinstance(node, ast.Call):\n",
" return ast.unparse(node)\n",
" # Handle function calls: convert to a representation with the function name and arguments\n",
" # func_name = ast_node_to_object(node.func) # Get the function name\n",
" # args = [ast_node_to_object(arg) for arg in node.args] # Convert all positional arguments\n",
" # kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in node.keywords} # Convert all keyword arguments\n",
" # return {\"function\": func_name, \"args\": args, \"kwargs\": kwargs}\n",
" elif isinstance(node, ast.Name):\n",
" return node.id # Return the identifier name\n",
" # Add more cases here as needed\n",
" return None\n",
"\n",
"# Parse the string into an AST\n",
"parsed_ast = ast.parse(input_str, mode='eval')\n",
"\n",
"# Function to find only the top-level Call nodes\n",
"def find_top_level_calls(node):\n",
" calls = []\n",
" if isinstance(node, ast.Call): # If it's a function call\n",
" calls.append(node)\n",
" # Do not descend into child nodes to ensure we're only capturing top-level calls\n",
" return calls\n",
" for child in ast.iter_child_nodes(node):\n",
" # Recursively find calls without going into nested calls\n",
" calls.extend(find_top_level_calls(child))\n",
" return calls\n",
"\n",
"# Extract all top-level function call nodes\n",
"top_level_calls = find_top_level_calls(parsed_ast.body)\n",
"\n",
"# Process each call node to get the details you want\n",
"functions = []\n",
"for call in top_level_calls:\n",
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
" function_name = call.func.id\n",
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
" functions.append((function_name, args, kwargs))\n",
"\n",
"print(functions)\n"
]
} }
], ],
"metadata": { "metadata": {