diff --git a/Makefile b/Makefile index c8fd3f5c5..60f81870c 100644 --- a/Makefile +++ b/Makefile @@ -669,6 +669,12 @@ clean: rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS) find examples pocs -type f -name "*.o" -delete +scanner.o: examples/server/tree_sitter/tree-sitter-python/src/scanner.c + $(CC) $(CFLAGS) -c $< -o $@ + +parser.o: examples/server/tree_sitter/tree-sitter-python/src/parser.c + $(CC) $(CFLAGS) -c $< -o $@ + # # Examples # @@ -735,8 +741,8 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) +server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) gguf: examples/gguf/gguf.cpp ggml.o $(OBJS) diff --git a/examples/server/python-parser.hpp b/examples/server/python-parser.hpp new file mode 100644 index 000000000..83d70f219 --- /dev/null +++ b/examples/server/python-parser.hpp @@ -0,0 +1,126 @@ +#include +#include +#include +#include +#include +#include "json.hpp" // Include the JSON library + +extern "C" TSLanguage *tree_sitter_python(); + +using json = nlohmann::json; // Use an alias for easier access + + +static json parseValue(const std::string& content) { + // Check for numerical value + if (!content.empty() && std::all_of(content.begin(), content.end(), ::isdigit)) { + return std::stoi(content); + } + // Check for boolean + if (content == "True" || content == "true") { + return true; + } else if (content == "False" || content == "false") { + return false; + } + if ((content.size() >= 2 && content.front() == '"' && content.back() == '"') || + (content.size() >= 2 && content.front() == '\'' && content.back() == '\'')) { + return content.substr(1, content.size() - 2); + } + // TODO: for array, dict, object, function, should further add logic to parse them recursively. + return content; +} + + +// Recursive function to parse and create JSON for the outer function calls +static void parseFunctionCalls(const TSNode& node, std::vector& calls, const char* source_code, uint32_t indent = 0) { + auto type = ts_node_type(node); + + printf("type: %s\n", type); + // Only interested in call_expression nodes at the outermost level + if (strcmp(type, "call") == 0) { + + json call = { + {"name", ""}, + {"args", json::array()}, + {"kwargs", json::object()} + }; + + TSNode functionNode = ts_node_child(node, 0); // The function name node + TSNode argumentsNode = ts_node_child(node, 1); // The arguments node + + // Extract the function name + call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode)); + + // Loop through the arguments + unsigned int numArgs = ts_node_named_child_count(argumentsNode); + for (unsigned int i = 0; i < numArgs; ++i) { + TSNode argNode = ts_node_named_child(argumentsNode, i); + const char* argType = ts_node_type(argNode); + + // Check if the argument is a positional argument or a keyword argument + if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) { + // For simplification, we treat the entire content as the argument + std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode)); + call["args"].push_back(parseValue(value)); + } else if (strcmp(argType, "keyword_argument") == 0) { + // Extract keyword and value for keyword arguments + TSNode keyNode = ts_node_child(argNode, 0); // The key of the kwarg + TSNode valueNode = ts_node_child(argNode, 2); // The value of the kwarg, 1 is the symbol `=` + + // if this is 0 then it's a string/integer/boolean, simply parse it + // unsigned int numValueNodeChild = ts_node_named_child_count(valueNode); + // TODO: if numValueNodeChild != 0 then it's an array/list/object?/function. Need to do something more. However for now we assume this will not happen. + + std::string key = std::string(source_code + ts_node_start_byte(keyNode), ts_node_end_byte(keyNode) - ts_node_start_byte(keyNode)); + std::string value = std::string(source_code + ts_node_start_byte(valueNode), ts_node_end_byte(valueNode) - ts_node_start_byte(valueNode)); + call["kwargs"][key] = parseValue(value); + } + } + + calls.push_back(call); + return; // Stop recursion to only process outer function calls + } + + // Recurse through all children for other node types + unsigned int numChildren = ts_node_child_count(node); + for (unsigned int i = 0; i < numChildren; ++i) { + TSNode child = ts_node_child(node, i); + parseFunctionCalls(child, calls, source_code, indent+1); + } +} + +static std::vector parsePythonFunctionCalls(std::string source_string) { + std::vector calls; + std::string delimiter = "<>"; + std::string source_code; + printf("source_string: %s\n", source_string.c_str()); + size_t startPos = source_string.find(delimiter); + if (startPos != std::string::npos) { + source_code = source_string.substr(startPos + delimiter.length()); + } else { + printf("no functions\n"); + return calls; + } + TSParser *parser = ts_parser_new(); + ts_parser_set_language(parser, tree_sitter_python()); + const char* source_code_cstr = source_code.c_str(); + TSTree *tree = ts_parser_parse_string(parser, nullptr, source_code_cstr, source_code.length()); + TSNode root_node = ts_tree_root_node(tree); + + bool has_errors = ts_node_has_error(root_node); + + if (has_errors) { + // probably a regular string + printf("has errors\n"); + return calls; + } + + parseFunctionCalls(root_node, calls, source_code_cstr, 0); + + // Output the parsed calls + + ts_tree_delete(tree); + ts_parser_delete(parser); + printf("calls: %s\n", json(calls).dump().c_str()); + return calls; +} + diff --git a/examples/server/tree_sitter/libtree-sitter.a b/examples/server/tree_sitter/libtree-sitter.a new file mode 100644 index 000000000..330f5a405 Binary files /dev/null and b/examples/server/tree_sitter/libtree-sitter.a differ diff --git a/examples/server/tree_sitter/tree-sitter-python b/examples/server/tree_sitter/tree-sitter-python new file mode 160000 index 000000000..b8a4c6412 --- /dev/null +++ b/examples/server/tree_sitter/tree-sitter-python @@ -0,0 +1 @@ +Subproject commit b8a4c64121ba66b460cb878e934e3157ecbfb124 diff --git a/examples/server/tree_sitter/tree_sitter/api.h b/examples/server/tree_sitter/tree_sitter/api.h new file mode 100644 index 000000000..69c65fab6 --- /dev/null +++ b/examples/server/tree_sitter/tree_sitter/api.h @@ -0,0 +1,1262 @@ +#ifndef TREE_SITTER_API_H_ +#define TREE_SITTER_API_H_ + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC visibility push(default) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +/****************************/ +/* Section - ABI Versioning */ +/****************************/ + +/** + * The latest ABI version that is supported by the current version of the + * library. When Languages are generated by the Tree-sitter CLI, they are + * assigned an ABI version number that corresponds to the current CLI version. + * The Tree-sitter library is generally backwards-compatible with languages + * generated using older CLI versions, but is not forwards-compatible. + */ +#define TREE_SITTER_LANGUAGE_VERSION 14 + +/** + * The earliest ABI version that is supported by the current version of the + * library. + */ +#define TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION 13 + +/*******************/ +/* Section - Types */ +/*******************/ + +typedef uint16_t TSStateId; +typedef uint16_t TSSymbol; +typedef uint16_t TSFieldId; +typedef struct TSLanguage TSLanguage; +typedef struct TSParser TSParser; +typedef struct TSTree TSTree; +typedef struct TSQuery TSQuery; +typedef struct TSQueryCursor TSQueryCursor; +typedef struct TSLookaheadIterator TSLookaheadIterator; + +typedef enum TSInputEncoding { + TSInputEncodingUTF8, + TSInputEncodingUTF16, +} TSInputEncoding; + +typedef enum TSSymbolType { + TSSymbolTypeRegular, + TSSymbolTypeAnonymous, + TSSymbolTypeAuxiliary, +} TSSymbolType; + +typedef struct TSPoint { + uint32_t row; + uint32_t column; +} TSPoint; + +typedef struct TSRange { + TSPoint start_point; + TSPoint end_point; + uint32_t start_byte; + uint32_t end_byte; +} TSRange; + +typedef struct TSInput { + void *payload; + const char *(*read)(void *payload, uint32_t byte_index, TSPoint position, uint32_t *bytes_read); + TSInputEncoding encoding; +} TSInput; + +typedef enum TSLogType { + TSLogTypeParse, + TSLogTypeLex, +} TSLogType; + +typedef struct TSLogger { + void *payload; + void (*log)(void *payload, TSLogType log_type, const char *buffer); +} TSLogger; + +typedef struct TSInputEdit { + uint32_t start_byte; + uint32_t old_end_byte; + uint32_t new_end_byte; + TSPoint start_point; + TSPoint old_end_point; + TSPoint new_end_point; +} TSInputEdit; + +typedef struct TSNode { + uint32_t context[4]; + const void *id; + const TSTree *tree; +} TSNode; + +typedef struct TSTreeCursor { + const void *tree; + const void *id; + uint32_t context[2]; +} TSTreeCursor; + +typedef struct TSQueryCapture { + TSNode node; + uint32_t index; +} TSQueryCapture; + +typedef enum TSQuantifier { + TSQuantifierZero = 0, // must match the array initialization value + TSQuantifierZeroOrOne, + TSQuantifierZeroOrMore, + TSQuantifierOne, + TSQuantifierOneOrMore, +} TSQuantifier; + +typedef struct TSQueryMatch { + uint32_t id; + uint16_t pattern_index; + uint16_t capture_count; + const TSQueryCapture *captures; +} TSQueryMatch; + +typedef enum TSQueryPredicateStepType { + TSQueryPredicateStepTypeDone, + TSQueryPredicateStepTypeCapture, + TSQueryPredicateStepTypeString, +} TSQueryPredicateStepType; + +typedef struct TSQueryPredicateStep { + TSQueryPredicateStepType type; + uint32_t value_id; +} TSQueryPredicateStep; + +typedef enum TSQueryError { + TSQueryErrorNone = 0, + TSQueryErrorSyntax, + TSQueryErrorNodeType, + TSQueryErrorField, + TSQueryErrorCapture, + TSQueryErrorStructure, + TSQueryErrorLanguage, +} TSQueryError; + +/********************/ +/* Section - Parser */ +/********************/ + +/** + * Create a new parser. + */ +TSParser *ts_parser_new(void); + +/** + * Delete the parser, freeing all of the memory that it used. + */ +void ts_parser_delete(TSParser *self); + +/** + * Get the parser's current language. + */ +const TSLanguage *ts_parser_language(const TSParser *self); + +/** + * Set the language that the parser should use for parsing. + * + * Returns a boolean indicating whether or not the language was successfully + * assigned. True means assignment succeeded. False means there was a version + * mismatch: the language was generated with an incompatible version of the + * Tree-sitter CLI. Check the language's version using [`ts_language_version`] + * and compare it to this library's [`TREE_SITTER_LANGUAGE_VERSION`] and + * [`TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION`] constants. + */ +bool ts_parser_set_language(TSParser *self, const TSLanguage *language); + +/** + * Set the ranges of text that the parser should include when parsing. + * + * By default, the parser will always include entire documents. This function + * allows you to parse only a *portion* of a document but still return a syntax + * tree whose ranges match up with the document as a whole. You can also pass + * multiple disjoint ranges. + * + * The second and third parameters specify the location and length of an array + * of ranges. The parser does *not* take ownership of these ranges; it copies + * the data, so it doesn't matter how these ranges are allocated. + * + * If `count` is zero, then the entire document will be parsed. Otherwise, + * the given ranges must be ordered from earliest to latest in the document, + * and they must not overlap. That is, the following must hold for all: + * + * `i < count - 1`: `ranges[i].end_byte <= ranges[i + 1].start_byte` + * + * If this requirement is not satisfied, the operation will fail, the ranges + * will not be assigned, and this function will return `false`. On success, + * this function returns `true` + */ +bool ts_parser_set_included_ranges( + TSParser *self, + const TSRange *ranges, + uint32_t count +); + +/** + * Get the ranges of text that the parser will include when parsing. + * + * The returned pointer is owned by the parser. The caller should not free it + * or write to it. The length of the array will be written to the given + * `count` pointer. + */ +const TSRange *ts_parser_included_ranges( + const TSParser *self, + uint32_t *count +); + +/** + * Use the parser to parse some source code and create a syntax tree. + * + * If you are parsing this document for the first time, pass `NULL` for the + * `old_tree` parameter. Otherwise, if you have already parsed an earlier + * version of this document and the document has since been edited, pass the + * previous syntax tree so that the unchanged parts of it can be reused. + * This will save time and memory. For this to work correctly, you must have + * already edited the old syntax tree using the [`ts_tree_edit`] function in a + * way that exactly matches the source code changes. + * + * The [`TSInput`] parameter lets you specify how to read the text. It has the + * following three fields: + * 1. [`read`]: A function to retrieve a chunk of text at a given byte offset + * and (row, column) position. The function should return a pointer to the + * text and write its length to the [`bytes_read`] pointer. The parser does + * not take ownership of this buffer; it just borrows it until it has + * finished reading it. The function should write a zero value to the + * [`bytes_read`] pointer to indicate the end of the document. + * 2. [`payload`]: An arbitrary pointer that will be passed to each invocation + * of the [`read`] function. + * 3. [`encoding`]: An indication of how the text is encoded. Either + * `TSInputEncodingUTF8` or `TSInputEncodingUTF16`. + * + * This function returns a syntax tree on success, and `NULL` on failure. There + * are three possible reasons for failure: + * 1. The parser does not have a language assigned. Check for this using the + [`ts_parser_language`] function. + * 2. Parsing was cancelled due to a timeout that was set by an earlier call to + * the [`ts_parser_set_timeout_micros`] function. You can resume parsing from + * where the parser left out by calling [`ts_parser_parse`] again with the + * same arguments. Or you can start parsing from scratch by first calling + * [`ts_parser_reset`]. + * 3. Parsing was cancelled using a cancellation flag that was set by an + * earlier call to [`ts_parser_set_cancellation_flag`]. You can resume parsing + * from where the parser left out by calling [`ts_parser_parse`] again with + * the same arguments. + * + * [`read`]: TSInput::read + * [`payload`]: TSInput::payload + * [`encoding`]: TSInput::encoding + * [`bytes_read`]: TSInput::read + */ +TSTree *ts_parser_parse( + TSParser *self, + const TSTree *old_tree, + TSInput input +); + +/** + * Use the parser to parse some source code stored in one contiguous buffer. + * The first two parameters are the same as in the [`ts_parser_parse`] function + * above. The second two parameters indicate the location of the buffer and its + * length in bytes. + */ +TSTree *ts_parser_parse_string( + TSParser *self, + const TSTree *old_tree, + const char *string, + uint32_t length +); + +/** + * Use the parser to parse some source code stored in one contiguous buffer with + * a given encoding. The first four parameters work the same as in the + * [`ts_parser_parse_string`] method above. The final parameter indicates whether + * the text is encoded as UTF8 or UTF16. + */ +TSTree *ts_parser_parse_string_encoding( + TSParser *self, + const TSTree *old_tree, + const char *string, + uint32_t length, + TSInputEncoding encoding +); + +/** + * Instruct the parser to start the next parse from the beginning. + * + * If the parser previously failed because of a timeout or a cancellation, then + * by default, it will resume where it left off on the next call to + * [`ts_parser_parse`] or other parsing functions. If you don't want to resume, + * and instead intend to use this parser to parse some other document, you must + * call [`ts_parser_reset`] first. + */ +void ts_parser_reset(TSParser *self); + +/** + * Set the maximum duration in microseconds that parsing should be allowed to + * take before halting. + * + * If parsing takes longer than this, it will halt early, returning NULL. + * See [`ts_parser_parse`] for more information. + */ +void ts_parser_set_timeout_micros(TSParser *self, uint64_t timeout_micros); + +/** + * Get the duration in microseconds that parsing is allowed to take. + */ +uint64_t ts_parser_timeout_micros(const TSParser *self); + +/** + * Set the parser's current cancellation flag pointer. + * + * If a non-null pointer is assigned, then the parser will periodically read + * from this pointer during parsing. If it reads a non-zero value, it will + * halt early, returning NULL. See [`ts_parser_parse`] for more information. + */ +void ts_parser_set_cancellation_flag(TSParser *self, const size_t *flag); + +/** + * Get the parser's current cancellation flag pointer. + */ +const size_t *ts_parser_cancellation_flag(const TSParser *self); + +/** + * Set the logger that a parser should use during parsing. + * + * The parser does not take ownership over the logger payload. If a logger was + * previously assigned, the caller is responsible for releasing any memory + * owned by the previous logger. + */ +void ts_parser_set_logger(TSParser *self, TSLogger logger); + +/** + * Get the parser's current logger. + */ +TSLogger ts_parser_logger(const TSParser *self); + +/** + * Set the file descriptor to which the parser should write debugging graphs + * during parsing. The graphs are formatted in the DOT language. You may want + * to pipe these graphs directly to a `dot(1)` process in order to generate + * SVG output. You can turn off this logging by passing a negative number. + */ +void ts_parser_print_dot_graphs(TSParser *self, int fd); + +/******************/ +/* Section - Tree */ +/******************/ + +/** + * Create a shallow copy of the syntax tree. This is very fast. + * + * You need to copy a syntax tree in order to use it on more than one thread at + * a time, as syntax trees are not thread safe. + */ +TSTree *ts_tree_copy(const TSTree *self); + +/** + * Delete the syntax tree, freeing all of the memory that it used. + */ +void ts_tree_delete(TSTree *self); + +/** + * Get the root node of the syntax tree. + */ +TSNode ts_tree_root_node(const TSTree *self); + +/** + * Get the root node of the syntax tree, but with its position + * shifted forward by the given offset. + */ +TSNode ts_tree_root_node_with_offset( + const TSTree *self, + uint32_t offset_bytes, + TSPoint offset_extent +); + +/** + * Get the language that was used to parse the syntax tree. + */ +const TSLanguage *ts_tree_language(const TSTree *self); + +/** + * Get the array of included ranges that was used to parse the syntax tree. + * + * The returned pointer must be freed by the caller. + */ +TSRange *ts_tree_included_ranges(const TSTree *self, uint32_t *length); + +/** + * Edit the syntax tree to keep it in sync with source code that has been + * edited. + * + * You must describe the edit both in terms of byte offsets and in terms of + * (row, column) coordinates. + */ +void ts_tree_edit(TSTree *self, const TSInputEdit *edit); + +/** + * Compare an old edited syntax tree to a new syntax tree representing the same + * document, returning an array of ranges whose syntactic structure has changed. + * + * For this to work correctly, the old syntax tree must have been edited such + * that its ranges match up to the new tree. Generally, you'll want to call + * this function right after calling one of the [`ts_parser_parse`] functions. + * You need to pass the old tree that was passed to parse, as well as the new + * tree that was returned from that function. + * + * The returned array is allocated using `malloc` and the caller is responsible + * for freeing it using `free`. The length of the array will be written to the + * given `length` pointer. + */ +TSRange *ts_tree_get_changed_ranges( + const TSTree *old_tree, + const TSTree *new_tree, + uint32_t *length +); + +/** + * Write a DOT graph describing the syntax tree to the given file. + */ +void ts_tree_print_dot_graph(const TSTree *self, int file_descriptor); + +/******************/ +/* Section - Node */ +/******************/ + +/** + * Get the node's type as a null-terminated string. + */ +const char *ts_node_type(TSNode self); + +/** + * Get the node's type as a numerical id. + */ +TSSymbol ts_node_symbol(TSNode self); + +/** + * Get the node's language. + */ +const TSLanguage *ts_node_language(TSNode self); + +/** + * Get the node's type as it appears in the grammar ignoring aliases as a + * null-terminated string. + */ +const char *ts_node_grammar_type(TSNode self); + +/** + * Get the node's type as a numerical id as it appears in the grammar ignoring + * aliases. This should be used in [`ts_language_next_state`] instead of + * [`ts_node_symbol`]. + */ +TSSymbol ts_node_grammar_symbol(TSNode self); + +/** + * Get the node's start byte. + */ +uint32_t ts_node_start_byte(TSNode self); + +/** + * Get the node's start position in terms of rows and columns. + */ +TSPoint ts_node_start_point(TSNode self); + +/** + * Get the node's end byte. + */ +uint32_t ts_node_end_byte(TSNode self); + +/** + * Get the node's end position in terms of rows and columns. + */ +TSPoint ts_node_end_point(TSNode self); + +/** + * Get an S-expression representing the node as a string. + * + * This string is allocated with `malloc` and the caller is responsible for + * freeing it using `free`. + */ +char *ts_node_string(TSNode self); + +/** + * Check if the node is null. Functions like [`ts_node_child`] and + * [`ts_node_next_sibling`] will return a null node to indicate that no such node + * was found. + */ +bool ts_node_is_null(TSNode self); + +/** + * Check if the node is *named*. Named nodes correspond to named rules in the + * grammar, whereas *anonymous* nodes correspond to string literals in the + * grammar. + */ +bool ts_node_is_named(TSNode self); + +/** + * Check if the node is *missing*. Missing nodes are inserted by the parser in + * order to recover from certain kinds of syntax errors. + */ +bool ts_node_is_missing(TSNode self); + +/** + * Check if the node is *extra*. Extra nodes represent things like comments, + * which are not required the grammar, but can appear anywhere. + */ +bool ts_node_is_extra(TSNode self); + +/** + * Check if a syntax node has been edited. + */ +bool ts_node_has_changes(TSNode self); + +/** + * Check if the node is a syntax error or contains any syntax errors. + */ +bool ts_node_has_error(TSNode self); + +/** + * Check if the node is a syntax error. +*/ +bool ts_node_is_error(TSNode self); + +/** + * Get this node's parse state. +*/ +TSStateId ts_node_parse_state(TSNode self); + +/** + * Get the parse state after this node. +*/ +TSStateId ts_node_next_parse_state(TSNode self); + +/** + * Get the node's immediate parent. + */ +TSNode ts_node_parent(TSNode self); + +/** + * Get the node's child at the given index, where zero represents the first + * child. + */ +TSNode ts_node_child(TSNode self, uint32_t child_index); + +/** + * Get the field name for node's child at the given index, where zero represents + * the first child. Returns NULL, if no field is found. + */ +const char *ts_node_field_name_for_child(TSNode self, uint32_t child_index); + +/** + * Get the node's number of children. + */ +uint32_t ts_node_child_count(TSNode self); + +/** + * Get the node's *named* child at the given index. + * + * See also [`ts_node_is_named`]. + */ +TSNode ts_node_named_child(TSNode self, uint32_t child_index); + +/** + * Get the node's number of *named* children. + * + * See also [`ts_node_is_named`]. + */ +uint32_t ts_node_named_child_count(TSNode self); + +/** + * Get the node's child with the given field name. + */ +TSNode ts_node_child_by_field_name( + TSNode self, + const char *name, + uint32_t name_length +); + +/** + * Get the node's child with the given numerical field id. + * + * You can convert a field name to an id using the + * [`ts_language_field_id_for_name`] function. + */ +TSNode ts_node_child_by_field_id(TSNode self, TSFieldId field_id); + +/** + * Get the node's next / previous sibling. + */ +TSNode ts_node_next_sibling(TSNode self); +TSNode ts_node_prev_sibling(TSNode self); + +/** + * Get the node's next / previous *named* sibling. + */ +TSNode ts_node_next_named_sibling(TSNode self); +TSNode ts_node_prev_named_sibling(TSNode self); + +/** + * Get the node's first child that extends beyond the given byte offset. + */ +TSNode ts_node_first_child_for_byte(TSNode self, uint32_t byte); + +/** + * Get the node's first named child that extends beyond the given byte offset. + */ +TSNode ts_node_first_named_child_for_byte(TSNode self, uint32_t byte); + +/** + * Get the node's number of descendants, including one for the node itself. + */ +uint32_t ts_node_descendant_count(TSNode self); + +/** + * Get the smallest node within this node that spans the given range of bytes + * or (row, column) positions. + */ +TSNode ts_node_descendant_for_byte_range(TSNode self, uint32_t start, uint32_t end); +TSNode ts_node_descendant_for_point_range(TSNode self, TSPoint start, TSPoint end); + +/** + * Get the smallest named node within this node that spans the given range of + * bytes or (row, column) positions. + */ +TSNode ts_node_named_descendant_for_byte_range(TSNode self, uint32_t start, uint32_t end); +TSNode ts_node_named_descendant_for_point_range(TSNode self, TSPoint start, TSPoint end); + +/** + * Edit the node to keep it in-sync with source code that has been edited. + * + * This function is only rarely needed. When you edit a syntax tree with the + * [`ts_tree_edit`] function, all of the nodes that you retrieve from the tree + * afterward will already reflect the edit. You only need to use [`ts_node_edit`] + * when you have a [`TSNode`] instance that you want to keep and continue to use + * after an edit. + */ +void ts_node_edit(TSNode *self, const TSInputEdit *edit); + +/** + * Check if two nodes are identical. + */ +bool ts_node_eq(TSNode self, TSNode other); + +/************************/ +/* Section - TreeCursor */ +/************************/ + +/** + * Create a new tree cursor starting from the given node. + * + * A tree cursor allows you to walk a syntax tree more efficiently than is + * possible using the [`TSNode`] functions. It is a mutable object that is always + * on a certain syntax node, and can be moved imperatively to different nodes. + */ +TSTreeCursor ts_tree_cursor_new(TSNode node); + +/** + * Delete a tree cursor, freeing all of the memory that it used. + */ +void ts_tree_cursor_delete(TSTreeCursor *self); + +/** + * Re-initialize a tree cursor to start at a different node. + */ +void ts_tree_cursor_reset(TSTreeCursor *self, TSNode node); + +/** + * Re-initialize a tree cursor to the same position as another cursor. + * + * Unlike [`ts_tree_cursor_reset`], this will not lose parent information and + * allows reusing already created cursors. +*/ +void ts_tree_cursor_reset_to(TSTreeCursor *dst, const TSTreeCursor *src); + +/** + * Get the tree cursor's current node. + */ +TSNode ts_tree_cursor_current_node(const TSTreeCursor *self); + +/** + * Get the field name of the tree cursor's current node. + * + * This returns `NULL` if the current node doesn't have a field. + * See also [`ts_node_child_by_field_name`]. + */ +const char *ts_tree_cursor_current_field_name(const TSTreeCursor *self); + +/** + * Get the field id of the tree cursor's current node. + * + * This returns zero if the current node doesn't have a field. + * See also [`ts_node_child_by_field_id`], [`ts_language_field_id_for_name`]. + */ +TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *self); + +/** + * Move the cursor to the parent of its current node. + * + * This returns `true` if the cursor successfully moved, and returns `false` + * if there was no parent node (the cursor was already on the root node). + */ +bool ts_tree_cursor_goto_parent(TSTreeCursor *self); + +/** + * Move the cursor to the next sibling of its current node. + * + * This returns `true` if the cursor successfully moved, and returns `false` + * if there was no next sibling node. + */ +bool ts_tree_cursor_goto_next_sibling(TSTreeCursor *self); + +/** + * Move the cursor to the previous sibling of its current node. + * + * This returns `true` if the cursor successfully moved, and returns `false` if + * there was no previous sibling node. + * + * Note, that this function may be slower than + * [`ts_tree_cursor_goto_next_sibling`] due to how node positions are stored. In + * the worst case, this will need to iterate through all the children upto the + * previous sibling node to recalculate its position. + */ +bool ts_tree_cursor_goto_previous_sibling(TSTreeCursor *self); + +/** + * Move the cursor to the first child of its current node. + * + * This returns `true` if the cursor successfully moved, and returns `false` + * if there were no children. + */ +bool ts_tree_cursor_goto_first_child(TSTreeCursor *self); + +/** + * Move the cursor to the last child of its current node. + * + * This returns `true` if the cursor successfully moved, and returns `false` if + * there were no children. + * + * Note that this function may be slower than [`ts_tree_cursor_goto_first_child`] + * because it needs to iterate through all the children to compute the child's + * position. + */ +bool ts_tree_cursor_goto_last_child(TSTreeCursor *self); + +/** + * Move the cursor to the node that is the nth descendant of + * the original node that the cursor was constructed with, where + * zero represents the original node itself. + */ +void ts_tree_cursor_goto_descendant(TSTreeCursor *self, uint32_t goal_descendant_index); + +/** + * Get the index of the cursor's current node out of all of the + * descendants of the original node that the cursor was constructed with. + */ +uint32_t ts_tree_cursor_current_descendant_index(const TSTreeCursor *self); + +/** + * Get the depth of the cursor's current node relative to the original + * node that the cursor was constructed with. + */ +uint32_t ts_tree_cursor_current_depth(const TSTreeCursor *self); + +/** + * Move the cursor to the first child of its current node that extends beyond + * the given byte offset or point. + * + * This returns the index of the child node if one was found, and returns -1 + * if no such child was found. + */ +int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *self, uint32_t goal_byte); +int64_t ts_tree_cursor_goto_first_child_for_point(TSTreeCursor *self, TSPoint goal_point); + +TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *cursor); + +/*******************/ +/* Section - Query */ +/*******************/ + +/** + * Create a new query from a string containing one or more S-expression + * patterns. The query is associated with a particular language, and can + * only be run on syntax nodes parsed with that language. + * + * If all of the given patterns are valid, this returns a [`TSQuery`]. + * If a pattern is invalid, this returns `NULL`, and provides two pieces + * of information about the problem: + * 1. The byte offset of the error is written to the `error_offset` parameter. + * 2. The type of error is written to the `error_type` parameter. + */ +TSQuery *ts_query_new( + const TSLanguage *language, + const char *source, + uint32_t source_len, + uint32_t *error_offset, + TSQueryError *error_type +); + +/** + * Delete a query, freeing all of the memory that it used. + */ +void ts_query_delete(TSQuery *self); + +/** + * Get the number of patterns, captures, or string literals in the query. + */ +uint32_t ts_query_pattern_count(const TSQuery *self); +uint32_t ts_query_capture_count(const TSQuery *self); +uint32_t ts_query_string_count(const TSQuery *self); + +/** + * Get the byte offset where the given pattern starts in the query's source. + * + * This can be useful when combining queries by concatenating their source + * code strings. + */ +uint32_t ts_query_start_byte_for_pattern(const TSQuery *self, uint32_t pattern_index); + +/** + * Get all of the predicates for the given pattern in the query. + * + * The predicates are represented as a single array of steps. There are three + * types of steps in this array, which correspond to the three legal values for + * the `type` field: + * - `TSQueryPredicateStepTypeCapture` - Steps with this type represent names + * of captures. Their `value_id` can be used with the + * [`ts_query_capture_name_for_id`] function to obtain the name of the capture. + * - `TSQueryPredicateStepTypeString` - Steps with this type represent literal + * strings. Their `value_id` can be used with the + * [`ts_query_string_value_for_id`] function to obtain their string value. + * - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels* + * that represent the end of an individual predicate. If a pattern has two + * predicates, then there will be two steps with this `type` in the array. + */ +const TSQueryPredicateStep *ts_query_predicates_for_pattern( + const TSQuery *self, + uint32_t pattern_index, + uint32_t *step_count +); + +/* + * Check if the given pattern in the query has a single root node. + */ +bool ts_query_is_pattern_rooted(const TSQuery *self, uint32_t pattern_index); + +/* + * Check if the given pattern in the query is 'non local'. + * + * A non-local pattern has multiple root nodes and can match within a + * repeating sequence of nodes, as specified by the grammar. Non-local + * patterns disable certain optimizations that would otherwise be possible + * when executing a query on a specific range of a syntax tree. + */ +bool ts_query_is_pattern_non_local(const TSQuery *self, uint32_t pattern_index); + +/* + * Check if a given pattern is guaranteed to match once a given step is reached. + * The step is specified by its byte offset in the query's source code. + */ +bool ts_query_is_pattern_guaranteed_at_step(const TSQuery *self, uint32_t byte_offset); + +/** + * Get the name and length of one of the query's captures, or one of the + * query's string literals. Each capture and string is associated with a + * numeric id based on the order that it appeared in the query's source. + */ +const char *ts_query_capture_name_for_id( + const TSQuery *self, + uint32_t index, + uint32_t *length +); + +/** + * Get the quantifier of the query's captures. Each capture is * associated + * with a numeric id based on the order that it appeared in the query's source. + */ +TSQuantifier ts_query_capture_quantifier_for_id( + const TSQuery *self, + uint32_t pattern_index, + uint32_t capture_index +); + +const char *ts_query_string_value_for_id( + const TSQuery *self, + uint32_t index, + uint32_t *length +); + +/** + * Disable a certain capture within a query. + * + * This prevents the capture from being returned in matches, and also avoids + * any resource usage associated with recording the capture. Currently, there + * is no way to undo this. + */ +void ts_query_disable_capture(TSQuery *self, const char *name, uint32_t length); + +/** + * Disable a certain pattern within a query. + * + * This prevents the pattern from matching and removes most of the overhead + * associated with the pattern. Currently, there is no way to undo this. + */ +void ts_query_disable_pattern(TSQuery *self, uint32_t pattern_index); + +/** + * Create a new cursor for executing a given query. + * + * The cursor stores the state that is needed to iteratively search + * for matches. To use the query cursor, first call [`ts_query_cursor_exec`] + * to start running a given query on a given syntax node. Then, there are + * two options for consuming the results of the query: + * 1. Repeatedly call [`ts_query_cursor_next_match`] to iterate over all of the + * *matches* in the order that they were found. Each match contains the + * index of the pattern that matched, and an array of captures. Because + * multiple patterns can match the same set of nodes, one match may contain + * captures that appear *before* some of the captures from a previous match. + * 2. Repeatedly call [`ts_query_cursor_next_capture`] to iterate over all of the + * individual *captures* in the order that they appear. This is useful if + * don't care about which pattern matched, and just want a single ordered + * sequence of captures. + * + * If you don't care about consuming all of the results, you can stop calling + * [`ts_query_cursor_next_match`] or [`ts_query_cursor_next_capture`] at any point. + * You can then start executing another query on another node by calling + * [`ts_query_cursor_exec`] again. + */ +TSQueryCursor *ts_query_cursor_new(void); + +/** + * Delete a query cursor, freeing all of the memory that it used. + */ +void ts_query_cursor_delete(TSQueryCursor *self); + +/** + * Start running a given query on a given node. + */ +void ts_query_cursor_exec(TSQueryCursor *self, const TSQuery *query, TSNode node); + +/** + * Manage the maximum number of in-progress matches allowed by this query + * cursor. + * + * Query cursors have an optional maximum capacity for storing lists of + * in-progress captures. If this capacity is exceeded, then the + * earliest-starting match will silently be dropped to make room for further + * matches. This maximum capacity is optional — by default, query cursors allow + * any number of pending matches, dynamically allocating new space for them as + * needed as the query is executed. + */ +bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *self); +uint32_t ts_query_cursor_match_limit(const TSQueryCursor *self); +void ts_query_cursor_set_match_limit(TSQueryCursor *self, uint32_t limit); + +/** + * Set the range of bytes or (row, column) positions in which the query + * will be executed. + */ +void ts_query_cursor_set_byte_range(TSQueryCursor *self, uint32_t start_byte, uint32_t end_byte); +void ts_query_cursor_set_point_range(TSQueryCursor *self, TSPoint start_point, TSPoint end_point); + +/** + * Advance to the next match of the currently running query. + * + * If there is a match, write it to `*match` and return `true`. + * Otherwise, return `false`. + */ +bool ts_query_cursor_next_match(TSQueryCursor *self, TSQueryMatch *match); +void ts_query_cursor_remove_match(TSQueryCursor *self, uint32_t match_id); + +/** + * Advance to the next capture of the currently running query. + * + * If there is a capture, write its match to `*match` and its index within + * the matche's capture list to `*capture_index`. Otherwise, return `false`. + */ +bool ts_query_cursor_next_capture( + TSQueryCursor *self, + TSQueryMatch *match, + uint32_t *capture_index +); + +/** + * Set the maximum start depth for a query cursor. + * + * This prevents cursors from exploring children nodes at a certain depth. + * Note if a pattern includes many children, then they will still be checked. + * + * The zero max start depth value can be used as a special behavior and + * it helps to destructure a subtree by staying on a node and using captures + * for interested parts. Note that the zero max start depth only limit a search + * depth for a pattern's root node but other nodes that are parts of the pattern + * may be searched at any depth what defined by the pattern structure. + * + * Set to `UINT32_MAX` to remove the maximum start depth. + */ +void ts_query_cursor_set_max_start_depth(TSQueryCursor *self, uint32_t max_start_depth); + +/**********************/ +/* Section - Language */ +/**********************/ + +/** + * Get another reference to the given language. + */ +const TSLanguage *ts_language_copy(const TSLanguage *self); + +/** + * Free any dynamically-allocated resources for this language, if + * this is the last reference. + */ +void ts_language_delete(const TSLanguage *self); + +/** + * Get the number of distinct node types in the language. + */ +uint32_t ts_language_symbol_count(const TSLanguage *self); + +/** + * Get the number of valid states in this language. +*/ +uint32_t ts_language_state_count(const TSLanguage *self); + +/** + * Get a node type string for the given numerical id. + */ +const char *ts_language_symbol_name(const TSLanguage *self, TSSymbol symbol); + +/** + * Get the numerical id for the given node type string. + */ +TSSymbol ts_language_symbol_for_name( + const TSLanguage *self, + const char *string, + uint32_t length, + bool is_named +); + +/** + * Get the number of distinct field names in the language. + */ +uint32_t ts_language_field_count(const TSLanguage *self); + +/** + * Get the field name string for the given numerical id. + */ +const char *ts_language_field_name_for_id(const TSLanguage *self, TSFieldId id); + +/** + * Get the numerical id for the given field name string. + */ +TSFieldId ts_language_field_id_for_name(const TSLanguage *self, const char *name, uint32_t name_length); + +/** + * Check whether the given node type id belongs to named nodes, anonymous nodes, + * or a hidden nodes. + * + * See also [`ts_node_is_named`]. Hidden nodes are never returned from the API. + */ +TSSymbolType ts_language_symbol_type(const TSLanguage *self, TSSymbol symbol); + +/** + * Get the ABI version number for this language. This version number is used + * to ensure that languages were generated by a compatible version of + * Tree-sitter. + * + * See also [`ts_parser_set_language`]. + */ +uint32_t ts_language_version(const TSLanguage *self); + +/** + * Get the next parse state. Combine this with lookahead iterators to generate + * completion suggestions or valid symbols in error nodes. Use + * [`ts_node_grammar_symbol`] for valid symbols. +*/ +TSStateId ts_language_next_state(const TSLanguage *self, TSStateId state, TSSymbol symbol); + +/********************************/ +/* Section - Lookahead Iterator */ +/********************************/ + +/** + * Create a new lookahead iterator for the given language and parse state. + * + * This returns `NULL` if state is invalid for the language. + * + * Repeatedly using [`ts_lookahead_iterator_next`] and + * [`ts_lookahead_iterator_current_symbol`] will generate valid symbols in the + * given parse state. Newly created lookahead iterators will contain the `ERROR` + * symbol. + * + * Lookahead iterators can be useful to generate suggestions and improve syntax + * error diagnostics. To get symbols valid in an ERROR node, use the lookahead + * iterator on its first leaf node state. For `MISSING` nodes, a lookahead + * iterator created on the previous non-extra leaf node may be appropriate. +*/ +TSLookaheadIterator *ts_lookahead_iterator_new(const TSLanguage *self, TSStateId state); + +/** + * Delete a lookahead iterator freeing all the memory used. +*/ +void ts_lookahead_iterator_delete(TSLookaheadIterator *self); + +/** + * Reset the lookahead iterator to another state. + * + * This returns `true` if the iterator was reset to the given state and `false` + * otherwise. +*/ +bool ts_lookahead_iterator_reset_state(TSLookaheadIterator *self, TSStateId state); + +/** + * Reset the lookahead iterator. + * + * This returns `true` if the language was set successfully and `false` + * otherwise. +*/ +bool ts_lookahead_iterator_reset(TSLookaheadIterator *self, const TSLanguage *language, TSStateId state); + +/** + * Get the current language of the lookahead iterator. +*/ +const TSLanguage *ts_lookahead_iterator_language(const TSLookaheadIterator *self); + +/** + * Advance the lookahead iterator to the next symbol. + * + * This returns `true` if there is a new symbol and `false` otherwise. +*/ +bool ts_lookahead_iterator_next(TSLookaheadIterator *self); + +/** + * Get the current symbol of the lookahead iterator; +*/ +TSSymbol ts_lookahead_iterator_current_symbol(const TSLookaheadIterator *self); + +/** + * Get the current symbol type of the lookahead iterator as a null terminated + * string. +*/ +const char *ts_lookahead_iterator_current_symbol_name(const TSLookaheadIterator *self); + +/*************************************/ +/* Section - WebAssembly Integration */ +/************************************/ + +typedef struct wasm_engine_t TSWasmEngine; +typedef struct TSWasmStore TSWasmStore; + +typedef enum { + TSWasmErrorKindNone = 0, + TSWasmErrorKindParse, + TSWasmErrorKindCompile, + TSWasmErrorKindInstantiate, + TSWasmErrorKindAllocate, +} TSWasmErrorKind; + +typedef struct { + TSWasmErrorKind kind; + char *message; +} TSWasmError; + +/** + * Create a Wasm store. + */ +TSWasmStore *ts_wasm_store_new( + TSWasmEngine *engine, + TSWasmError *error +); + +/** + * Free the memory associated with the given Wasm store. + */ +void ts_wasm_store_delete(TSWasmStore *); + +/** + * Create a language from a buffer of Wasm. The resulting language behaves + * like any other Tree-sitter language, except that in order to use it with + * a parser, that parser must have a Wasm store. Note that the language + * can be used with any Wasm store, it doesn't need to be the same store that + * was used to originally load it. + */ +const TSLanguage *ts_wasm_store_load_language( + TSWasmStore *, + const char *name, + const char *wasm, + uint32_t wasm_len, + TSWasmError *error +); + +/** + * Get the number of languages instantiated in the given wasm store. + */ +size_t ts_wasm_store_language_count(const TSWasmStore *); + +/** + * Check if the language came from a Wasm module. If so, then in order to use + * this language with a Parser, that parser must have a Wasm store assigned. + */ +bool ts_language_is_wasm(const TSLanguage *); + +/** + * Assign the given Wasm store to the parser. A parser must have a Wasm store + * in order to use Wasm languages. + */ +void ts_parser_set_wasm_store(TSParser *, TSWasmStore *); + +/** + * Remove the parser's current Wasm store and return it. This returns NULL if + * the parser doesn't have a Wasm store. + */ +TSWasmStore *ts_parser_take_wasm_store(TSParser *); + +/**********************************/ +/* Section - Global Configuration */ +/**********************************/ + +/** + * Set the allocation functions used by the library. + * + * By default, Tree-sitter uses the standard libc allocation functions, + * but aborts the process when an allocation fails. This function lets + * you supply alternative allocation functions at runtime. + * + * If you pass `NULL` for any parameter, Tree-sitter will switch back to + * its default implementation of that function. + * + * If you call this function after the library has already been used, then + * you must ensure that either: + * 1. All the existing objects have been freed. + * 2. The new allocator shares its state with the old one, so it is capable + * of freeing memory that was allocated by the old allocator. + */ +void ts_set_allocator( + void *(*new_malloc)(size_t), + void *(*new_calloc)(size_t, size_t), + void *(*new_realloc)(void *, size_t), + void (*new_free)(void *) +); + +#ifdef __cplusplus +} +#endif + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC visibility pop +#endif + +#endif // TREE_SITTER_API_H_ diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index c17dbcd84..fea29c33c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -4,6 +4,7 @@ #include "common.h" #include "json.hpp" +#include "python-parser.hpp" #include #include @@ -421,7 +422,7 @@ static std::string rubra_format_function_call_str(const std::vector & func return final_str; } -std::string default_tool_formatter(const std::vector& tools) { +static std::string default_tool_formatter(const std::vector& tools) { std::string toolText = ""; std::vector toolNames; for (const auto& tool : tools) { @@ -556,11 +557,6 @@ static json oaicompat_completion_params_parse( } -static json parse_response_for_function_call(const std::string content) { - -} - - static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) { bool stopped_word = result.count("stopped_word") != 0; bool stopped_eos = json_value(result, "stopped_eos", false); @@ -568,6 +564,9 @@ static json format_final_response_oaicompat(const json & request, json result, c int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); std::string content = json_value(result, "content", std::string("")); + std::vector parsed_content = parsePythonFunctionCalls(content); + + std::string finish_reason = "length"; if (stopped_word || stopped_eos) { finish_reason = "stop"; @@ -579,7 +578,7 @@ static json format_final_response_oaicompat(const json & request, json result, c {"delta", json::object()}}}) : json::array({json{{"finish_reason", finish_reason}, {"index", 0}, - {"message", json{{"content", content}, + {"message", json{{"content", parsed_content}, {"role", "assistant"}}}}}); std::time_t t = std::time(0); diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb index 15781f283..5302ac8b8 100644 --- a/test_llamacpp.ipynb +++ b/test_llamacpp.ipynb @@ -29,14 +29,14 @@ }, { "cell_type": "code", - "execution_count": 161, + "execution_count": 163, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n" + " <>[orderUmbrella(brand_name=\"Patagonia\")]\n" ] } ], @@ -93,14 +93,15 @@ }, { "cell_type": "code", - "execution_count": 160, + "execution_count": 177, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n" + " <>[get_stock_fundermentals(symbol=\"TSLA\")]\n", + "<>[get_stock_fundermentals(symbol=\"GOOG\")]\n" ] } ], @@ -109,8 +110,8 @@ "functions = [\n", " {\"function\":\n", " {\n", - " \"name\": \"get_stock_price\",\n", - " \"description\": \"Get the current stock price\",\n", + " \"name\": \"get_stock_fundermentals\",\n", + " \"description\": \"Get the stock fundermentals data\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", @@ -147,14 +148,140 @@ " }}\n", "]\n", "\n", - "user_query = \"What's the stock price of Tesla and Google?\"\n", + "user_query = \"What's the stock fundementals of Tesla and google\"\n", "\n", "# \n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<>170, 138\"}]\n", - "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", - "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", + "res = get_mistral_rubra_response(user_query, \"mistral_rubra\", functions=functions, msgs=msgs)\n", "print(res.message.content)" ] + }, + { + "cell_type": "code", + "execution_count": 183, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n" + ] + } + ], + "source": [ + "import ast\n", + "\n", + "input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[1, 2, {'a': 1, 'b': 2}])]\"\n", + "\n", + "# Parse the string into an AST\n", + "parsed_ast = ast.parse(input_str, mode='eval')\n", + "\n", + "# Function to convert an AST node to a Python object\n", + "def ast_node_to_object(node):\n", + " if isinstance(node, ast.Constant):\n", + " return node.value\n", + " elif isinstance(node, ast.List):\n", + " return [ast_node_to_object(n) for n in node.elts]\n", + " elif isinstance(node, ast.Dict):\n", + " return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n", + " elif isinstance(node, ast.Tuple):\n", + " return tuple(ast_node_to_object(n) for n in node.elts)\n", + " # Add more cases here as needed\n", + " return None\n", + "\n", + "def find_calls(node):\n", + " calls = []\n", + " if isinstance(node, ast.Call): # If it's a function call\n", + " calls.append(node)\n", + " for child in ast.iter_child_nodes(node):\n", + " calls.extend(find_calls(child))\n", + " return calls\n", + "\n", + "# Extract all function call nodes\n", + "calls = find_calls(parsed_ast.body)\n", + "\n", + "functions = []\n", + "for call in calls:\n", + " if isinstance(call.func, ast.Name): # Ensure it's a named function\n", + " function_name = call.func.id\n", + " args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n", + " kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n", + " functions.append((function_name, args, kwargs))\n", + "\n", + "print(functions)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': ['func_nested(1, 2)', {'a': \"func_deep('value')\"}]})]\n" + ] + } + ], + "source": [ + "import ast\n", + "\n", + "input_str = \"[get_current_weather(location='Boston, MA', api_key=123456789, unit='fahrenheit'), func('cde', x=1, b='2', c=[func_nested(1, 2), {'a': func_deep('value')}])]\"\n", + "\n", + "def ast_node_to_object(node):\n", + " if isinstance(node, ast.Constant):\n", + " return node.value\n", + " elif isinstance(node, ast.List):\n", + " return [ast_node_to_object(n) for n in node.elts]\n", + " elif isinstance(node, ast.Dict):\n", + " return {ast_node_to_object(key): ast_node_to_object(value) for key, value in zip(node.keys, node.values)}\n", + " elif isinstance(node, ast.Tuple):\n", + " return tuple(ast_node_to_object(n) for n in node.elts)\n", + " elif isinstance(node, ast.Call):\n", + " return ast.unparse(node)\n", + " # Handle function calls: convert to a representation with the function name and arguments\n", + " # func_name = ast_node_to_object(node.func) # Get the function name\n", + " # args = [ast_node_to_object(arg) for arg in node.args] # Convert all positional arguments\n", + " # kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in node.keywords} # Convert all keyword arguments\n", + " # return {\"function\": func_name, \"args\": args, \"kwargs\": kwargs}\n", + " elif isinstance(node, ast.Name):\n", + " return node.id # Return the identifier name\n", + " # Add more cases here as needed\n", + " return None\n", + "\n", + "# Parse the string into an AST\n", + "parsed_ast = ast.parse(input_str, mode='eval')\n", + "\n", + "# Function to find only the top-level Call nodes\n", + "def find_top_level_calls(node):\n", + " calls = []\n", + " if isinstance(node, ast.Call): # If it's a function call\n", + " calls.append(node)\n", + " # Do not descend into child nodes to ensure we're only capturing top-level calls\n", + " return calls\n", + " for child in ast.iter_child_nodes(node):\n", + " # Recursively find calls without going into nested calls\n", + " calls.extend(find_top_level_calls(child))\n", + " return calls\n", + "\n", + "# Extract all top-level function call nodes\n", + "top_level_calls = find_top_level_calls(parsed_ast.body)\n", + "\n", + "# Process each call node to get the details you want\n", + "functions = []\n", + "for call in top_level_calls:\n", + " if isinstance(call.func, ast.Name): # Ensure it's a named function\n", + " function_name = call.func.id\n", + " args = [ast_node_to_object(arg) for arg in call.args] # Convert all positional arguments\n", + " kwargs = {kw.arg: ast_node_to_object(kw.value) for kw in call.keywords} # Convert all keyword arguments\n", + " functions.append((function_name, args, kwargs))\n", + "\n", + "print(functions)\n" + ] } ], "metadata": {