a first working version integrated tree_sitter with python parser code
This commit is contained in:
parent
48c02498f2
commit
9bd7dbb17b
7 changed files with 1540 additions and 19 deletions
10
Makefile
10
Makefile
|
@ -669,6 +669,12 @@ clean:
|
|||
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
|
||||
find examples pocs -type f -name "*.o" -delete
|
||||
|
||||
scanner.o: examples/server/tree_sitter/tree-sitter-python/src/scanner.c
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
parser.o: examples/server/tree_sitter/tree-sitter-python/src/parser.c
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
|
@ -735,8 +741,8 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||
|
||||
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
|
||||
|
|
126
examples/server/python-parser.hpp
Normal file
126
examples/server/python-parser.hpp
Normal file
|
@ -0,0 +1,126 @@
|
|||
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <tree_sitter/api.h>

#include "json.hpp" // Include the JSON library
|
||||
|
||||
extern "C" TSLanguage *tree_sitter_python();
|
||||
|
||||
using json = nlohmann::json; // Use an alias for easier access
|
||||
|
||||
|
||||
// Converts the source text of a single Python argument into a JSON value.
//
// Recognised forms:
//   * decimal integers with an optional leading '+' or '-'  -> JSON number
//   * True / true / False / false                            -> JSON boolean
//   * text wrapped in matching single or double quotes       -> JSON string
//     (the quotes are stripped; escape sequences are NOT decoded)
// Anything else is returned unchanged as a JSON string.
//
// @param content  raw text of the argument exactly as it appears in the source
// @return         the converted JSON value
static json parseValue(const std::string& content) {
    // Integer literal: optional sign followed by at least one decimal digit.
    const std::size_t digitsStart =
        (!content.empty() && (content[0] == '-' || content[0] == '+')) ? 1 : 0;
    const bool isInteger =
        digitsStart < content.size() &&
        content.find_first_not_of("0123456789", digitsStart) == std::string::npos;

    if (isInteger) {
        // strtoll/strtoull report overflow through errno instead of throwing,
        // so an over-long digit string cannot abort the request.
        errno = 0;
        const long long signedValue = std::strtoll(content.c_str(), nullptr, 10);
        if (errno != ERANGE) {
            return json(signedValue);
        }
        if (content[0] != '-') {
            errno = 0;
            const unsigned long long unsignedValue = std::strtoull(content.c_str(), nullptr, 10);
            if (errno != ERANGE) {
                return json(unsignedValue);
            }
        }
        // Does not fit in 64 bits: fall through and keep the literal as text.
    }

    // Boolean literal (Python and JSON spellings).
    if (content == "True" || content == "true") {
        return true;
    }
    if (content == "False" || content == "false") {
        return false;
    }

    // Quoted string: strip the surrounding pair of matching quotes.
    if (content.size() >= 2) {
        const char first = content.front();
        const char last = content.back();
        if ((first == '"' && last == '"') || (first == '\'' && last == '\'')) {
            return content.substr(1, content.size() - 2);
        }
    }

    // TODO: for array, dict, object, function, should further add logic to parse them recursively.
    return content;
}
|
||||
|
||||
|
||||
// Recursive function to parse and create JSON for the outer function calls
|
||||
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
|
||||
auto type = ts_node_type(node);
|
||||
|
||||
printf("type: %s\n", type);
|
||||
// Only interested in call_expression nodes at the outermost level
|
||||
if (strcmp(type, "call") == 0) {
|
||||
|
||||
json call = {
|
||||
{"name", ""},
|
||||
{"args", json::array()},
|
||||
{"kwargs", json::object()}
|
||||
};
|
||||
|
||||
TSNode functionNode = ts_node_child(node, 0); // The function name node
|
||||
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
|
||||
|
||||
// Extract the function name
|
||||
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
|
||||
|
||||
// Loop through the arguments
|
||||
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
|
||||
for (unsigned int i = 0; i < numArgs; ++i) {
|
||||
TSNode argNode = ts_node_named_child(argumentsNode, i);
|
||||
const char* argType = ts_node_type(argNode);
|
||||
|
||||
// Check if the argument is a positional argument or a keyword argument
|
||||
if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) {
|
||||
// For simplification, we treat the entire content as the argument
|
||||
std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode));
|
||||
call["args"].push_back(parseValue(value));
|
||||
} else if (strcmp(argType, "keyword_argument") == 0) {
|
||||
// Extract keyword and value for keyword arguments
|
||||
TSNode keyNode = ts_node_child(argNode, 0); // The key of the kwarg
|
||||
TSNode valueNode = ts_node_child(argNode, 2); // The value of the kwarg, 1 is the symbol `=`
|
||||
|
||||
// if this is 0 then it's a string/integer/boolean, simply parse it
|
||||
// unsigned int numValueNodeChild = ts_node_named_child_count(valueNode);
|
||||
// TODO: if numValueNodeChild != 0 then it's an array/list/object?/function. Need to do something more. However for now we assume this will not happen.
|
||||
|
||||
std::string key = std::string(source_code + ts_node_start_byte(keyNode), ts_node_end_byte(keyNode) - ts_node_start_byte(keyNode));
|
||||
std::string value = std::string(source_code + ts_node_start_byte(valueNode), ts_node_end_byte(valueNode) - ts_node_start_byte(valueNode));
|
||||
call["kwargs"][key] = parseValue(value);
|
||||
}
|
||||
}
|
||||
|
||||
calls.push_back(call);
|
||||
return; // Stop recursion to only process outer function calls
|
||||
}
|
||||
|
||||
// Recurse through all children for other node types
|
||||
unsigned int numChildren = ts_node_child_count(node);
|
||||
for (unsigned int i = 0; i < numChildren; ++i) {
|
||||
TSNode child = ts_node_child(node, i);
|
||||
parseFunctionCalls(child, calls, source_code, indent+1);
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
||||
std::vector<json> calls;
|
||||
std::string delimiter = "<<functions>>";
|
||||
std::string source_code;
|
||||
printf("source_string: %s\n", source_string.c_str());
|
||||
size_t startPos = source_string.find(delimiter);
|
||||
if (startPos != std::string::npos) {
|
||||
source_code = source_string.substr(startPos + delimiter.length());
|
||||
} else {
|
||||
printf("no functions\n");
|
||||
return calls;
|
||||
}
|
||||
TSParser *parser = ts_parser_new();
|
||||
ts_parser_set_language(parser, tree_sitter_python());
|
||||
const char* source_code_cstr = source_code.c_str();
|
||||
TSTree *tree = ts_parser_parse_string(parser, nullptr, source_code_cstr, source_code.length());
|
||||
TSNode root_node = ts_tree_root_node(tree);
|
||||
|
||||
bool has_errors = ts_node_has_error(root_node);
|
||||
|
||||
if (has_errors) {
|
||||
// probably a regular string
|
||||
printf("has errors\n");
|
||||
return calls;
|
||||
}
|
||||
|
||||
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
|
||||
|
||||
// Output the parsed calls
|
||||
|
||||
ts_tree_delete(tree);
|
||||
ts_parser_delete(parser);
|
||||
printf("calls: %s\n", json(calls).dump().c_str());
|
||||
return calls;
|
||||
}
|
||||
|
BIN
examples/server/tree_sitter/libtree-sitter.a
Normal file
BIN
examples/server/tree_sitter/libtree-sitter.a
Normal file
Binary file not shown.
1
examples/server/tree_sitter/tree-sitter-python
Submodule
1
examples/server/tree_sitter/tree-sitter-python
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit b8a4c64121ba66b460cb878e934e3157ecbfb124
|
1262
examples/server/tree_sitter/tree_sitter/api.h
Normal file
1262
examples/server/tree_sitter/tree_sitter/api.h
Normal file
File diff suppressed because it is too large
Load diff
|
@ -4,6 +4,7 @@
|
|||
#include "common.h"
|
||||
|
||||
#include "json.hpp"
|
||||
#include "python-parser.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -421,7 +422,7 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
|
|||
return final_str;
|
||||
}
|
||||
|
||||
std::string default_tool_formatter(const std::vector<json>& tools) {
|
||||
static std::string default_tool_formatter(const std::vector<json>& tools) {
|
||||
std::string toolText = "";
|
||||
std::vector<std::string> toolNames;
|
||||
for (const auto& tool : tools) {
|
||||
|
@ -556,11 +557,6 @@ static json oaicompat_completion_params_parse(
|
|||
}
|
||||
|
||||
|
||||
static json parse_response_for_function_call(const std::string content) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
|
||||
bool stopped_word = result.count("stopped_word") != 0;
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
|
@ -568,6 +564,9 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
|||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
|
||||
std::vector<json> parsed_content = parsePythonFunctionCalls(content);
|
||||
|
||||
|
||||
std::string finish_reason = "length";
|
||||
if (stopped_word || stopped_eos) {
|
||||
finish_reason = "stop";
|
||||
|
@ -579,7 +578,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
|||
{"delta", json::object()}}})
|
||||
: json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{{"content", content},
|
||||
{"message", json{{"content", parsed_content},
|
||||
{"role", "assistant"}}}}});
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
|
|
@ -29,14 +29,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 161,
|
||||
"execution_count": 163,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n"
|
||||
" <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -93,14 +93,15 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 160,
|
||||
"execution_count": 177,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n"
|
||||
" <<functions>>[get_stock_fundermentals(symbol=\"TSLA\")]\n",
|
||||
"<<functions>>[get_stock_fundermentals(symbol=\"GOOG\")]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -109,8 +110,8 @@
|
|||
"functions = [\n",
|
||||
" {\"function\":\n",
|
||||
" {\n",
|
||||
" \"name\": \"get_stock_price\",\n",
|
||||
" \"description\": \"Get the current stock price\",\n",
|
||||
" \"name\": \"get_stock_fundermentals\",\n",
|
||||
" \"description\": \"Get the stock fundermentals data\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
|
@ -147,14 +148,140 @@
|
|||
" }}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"user_query = \"What's the stock price of Tesla and Google?\"\n",
|
||||
"user_query = \"What's the stock fundementals of Tesla and google\"\n",
|
||||
"\n",
|
||||
"# \n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>170, 138\"}]\n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
|
||||
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
|
||||
"res = get_mistral_rubra_response(user_query, \"mistral_rubra\", functions=functions, msgs=msgs)\n",
|
||||
"print(res.message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 183,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<ast.List object at 0x105649390>\n",
|
||||
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import ast\n",
|
||||
"\n",
|
||||
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[1, 2, {'a': 1, 'b': 2}])]\"\n",
|
||||
"\n",
|
||||
"# Parse the string into an AST\n",
|
||||
"parsed_ast = ast.parse(input_str, mode='eval')\n",
|
||||
"\n",
|
||||
"# Function to convert an AST node to a Python object\n",
|
||||
"def ast_node_to_object(node):\n",
|
||||
" if isinstance(node, ast.Constant):\n",
|
||||
" return node.value\n",
|
||||
" elif isinstance(node, ast.List):\n",
|
||||
" return [ast_node_to_object(n) for n in node.elts]\n",
|
||||
" elif isinstance(node, ast.Dict):\n",
|
||||
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
|
||||
" elif isinstance(node, ast.Tuple):\n",
|
||||
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
|
||||
" # Add more cases here as needed\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"def find_calls(node):\n",
|
||||
" calls = []\n",
|
||||
" if isinstance(node, ast.Call): # If it's a function call\n",
|
||||
" calls.append(node)\n",
|
||||
" for child in ast.iter_child_nodes(node):\n",
|
||||
" calls.extend(find_calls(child))\n",
|
||||
" return calls\n",
|
||||
"\n",
|
||||
"# Extract all function call nodes\n",
|
||||
"calls = find_calls(parsed_ast.body)\n",
|
||||
"\n",
|
||||
"functions = []\n",
|
||||
"for call in calls:\n",
|
||||
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
|
||||
" function_name = call.func.id\n",
|
||||
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
|
||||
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
|
||||
" functions.append((function_name, args, kwargs))\n",
|
||||
"\n",
|
||||
"print(functions)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': ['func_nested(1, 2)', {'a': \"func_deep('value')\"}]})]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import ast\n",
|
||||
"\n",
|
||||
"input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[func_nested(1, 2), {'a': func_deep('value')}])]\"\n",
|
||||
"\n",
|
||||
"def ast_node_to_object(node):\n",
|
||||
" if isinstance(node, ast.Constant):\n",
|
||||
" return node.value\n",
|
||||
" elif isinstance(node, ast.List):\n",
|
||||
" return [ast_node_to_object(n) for n in node.elts]\n",
|
||||
" elif isinstance(node, ast.Dict):\n",
|
||||
" return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n",
|
||||
" elif isinstance(node, ast.Tuple):\n",
|
||||
" return tuple(ast_node_to_object(n) for n in node.elts)\n",
|
||||
" elif isinstance(node, ast.Call):\n",
|
||||
" return ast.unparse(node)\n",
|
||||
" # Handle function calls: convert to a representation with the function name and arguments\n",
|
||||
" # func_name = ast_node_to_object(node.func) # Get the function name\n",
|
||||
" # args = [ast_node_to_object(arg) for arg in node.args] # Convert all positional arguments\n",
|
||||
" # kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in node.keywords} # Convert all keyword arguments\n",
|
||||
" # return {\"function\": func_name, \"args\": args, \"kwargs\": kwargs}\n",
|
||||
" elif isinstance(node, ast.Name):\n",
|
||||
" return node.id # Return the identifier name\n",
|
||||
" # Add more cases here as needed\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"# Parse the string into an AST\n",
|
||||
"parsed_ast = ast.parse(input_str, mode='eval')\n",
|
||||
"\n",
|
||||
"# Function to find only the top-level Call nodes\n",
|
||||
"def find_top_level_calls(node):\n",
|
||||
" calls = []\n",
|
||||
" if isinstance(node, ast.Call): # If it's a function call\n",
|
||||
" calls.append(node)\n",
|
||||
" # Do not descend into child nodes to ensure we're only capturing top-level calls\n",
|
||||
" return calls\n",
|
||||
" for child in ast.iter_child_nodes(node):\n",
|
||||
" # Recursively find calls without going into nested calls\n",
|
||||
" calls.extend(find_top_level_calls(child))\n",
|
||||
" return calls\n",
|
||||
"\n",
|
||||
"# Extract all top-level function call nodes\n",
|
||||
"top_level_calls = find_top_level_calls(parsed_ast.body)\n",
|
||||
"\n",
|
||||
"# Process each call node to get the details you want\n",
|
||||
"functions = []\n",
|
||||
"for call in top_level_calls:\n",
|
||||
" if isinstance(call.func, ast.Name): # Ensure it's a named function\n",
|
||||
" function_name = call.func.id\n",
|
||||
" args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n",
|
||||
" kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n",
|
||||
" functions.append((function_name, args, kwargs))\n",
|
||||
"\n",
|
||||
"print(functions)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue