From 5268ec8947772fd270767a506273f3c29a593f6d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 02:08:18 +0000 Subject: [PATCH] Refactor string helpers into common --- common/common.cpp | 42 +++++++++++++++++++ common/common.h | 4 ++ common/json-schema-to-grammar.cpp | 69 ++++++------------------------- common/json-schema-to-grammar.h | 3 -- common/tool-call.cpp | 10 ++--- 5 files changed, 64 insertions(+), 64 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 046e236f2..76365c5c0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -485,6 +485,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector tokens; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + tokens.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + tokens.push_back(str.substr(start)); + + return tokens; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? "true" : "false"; } diff --git a/common/common.h b/common/common.h index 964ea0351..e33ba8a90 100644 --- a/common/common.h +++ b/common/common.h @@ -431,6 +431,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 677aca680..dacaa1fc3 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,8 +13,6 @@ using json = nlohmann::ordered_json; -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -125,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -185,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -315,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -416,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -479,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -537,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -852,7 +809,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -903,7 +860,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1017,10 +974,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 9a8b0f3ce..4f43ab3a5 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,9 +5,6 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -template -std::string join(Iterator begin, Iterator end, const std::string & separator); - std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); struct llama_grammar_builder { diff --git a/common/tool-call.cpp b/common/tool-call.cpp index a3ea52290..cf7d330f7 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -631,7 +631,7 @@ common_tool_call_handler common_tool_call_handler_init( handler.grammar_triggers.push_back("{\n \""); } - builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); + builder.add_rule("root", string_join(tool_rules, " | ")); }); handler.additional_stops.push_back("<|eom_id|>"); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, { @@ -658,9 +658,9 @@ common_tool_call_handler common_tool_call_handler_init( handler.grammar_triggers.push_back("\n>>>" + name + "\n"); } } - auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; + auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; if (parallel) { - auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); } else { builder.add_rule("root", first_rule); @@ -690,7 +690,7 @@ common_tool_call_handler common_tool_call_handler_init( tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); } } - auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_triggers.push_back("\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; + auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_triggers.push_back("");