Refactor string helpers into common

This commit is contained in:
Olivier Chafik 2025-01-22 02:08:18 +00:00
parent d77fecc3dc
commit 5268ec8947
5 changed files with 64 additions and 64 deletions

View file

@ -485,6 +485,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder); s = std::move(builder);
} }
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
return result.str();
}
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
std::string string_repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
std::string string_from(bool value) { std::string string_from(bool value) {
return value ? "true" : "false"; return value ? "true" : "false";
} }

View file

@ -431,6 +431,10 @@ std::string string_format(const char * fmt, ...);
std::string string_strip(const std::string & str); std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp(); std::string string_get_sortable_timestamp();
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace); void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
template<class T> template<class T>

View file

@ -1,4 +1,6 @@
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "common.h"
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <map> #include <map>
@ -11,8 +13,6 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
static std::string repeat(const std::string & str, size_t n);
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max(); auto has_max = max_items != std::numeric_limits<int>::max();
@ -125,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (sub_len > 0) { if (sub_len > 0) {
auto from_sub = from.substr(i + 1); auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1); auto to_sub = to.substr(i + 1);
auto sub_zeros = repeat("0", sub_len); auto sub_zeros = string_repeat("0", sub_len);
auto sub_nines = repeat("9", sub_len); auto sub_nines = string_repeat("9", sub_len);
auto to_reached = false; auto to_reached = false;
out << "("; out << "(";
@ -185,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
auto max_digits = max_s.length(); auto max_digits = max_s.length();
for (auto digits = min_digits; digits < max_digits; digits++) { for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, repeat("9", digits)); uniform_range(min_s, string_repeat("9", digits));
min_s = "1" + repeat("0", digits); min_s = "1" + string_repeat("0", digits);
out << " | "; out << " | ";
} }
uniform_range(min_s, max_s); uniform_range(min_s, max_s);
@ -315,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) { static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match; std::smatch match;
std::string result; std::string result;
@ -416,7 +373,7 @@ private:
for (size_t i = 0; i < alt_schemas.size(); i++) { for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
} }
return join(rules.begin(), rules.end(), " | "); return string_join(rules, " | ");
} }
std::string _visit_pattern(const std::string & pattern, const std::string & name) { std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@ -479,7 +436,7 @@ private:
for (const auto & item : ret) { for (const auto & item : ret) {
results.push_back(to_rule(item)); results.push_back(to_rule(item));
} }
return std::make_pair(join(results.begin(), results.end(), " "), false); return std::make_pair(string_join(results, " "), false);
}; };
while (i < length) { while (i < length) {
@ -537,7 +494,7 @@ private:
} }
curly_brackets += '}'; curly_brackets += '}';
i++; i++;
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0; int min_times = 0;
int max_times = std::numeric_limits<int>::max(); int max_times = std::numeric_limits<int>::max();
try { try {
@ -852,7 +809,7 @@ public:
return; return;
} }
std::string pointer = ref.substr(ref.find('#') + 1); std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/"); std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) { for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i]; std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) { if (target.is_null() || !target.contains(sel)) {
@ -903,7 +860,7 @@ public:
for (const auto & v : schema["enum"]) { for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v)); enum_values.push_back(_generate_constant_rule(v));
} }
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object") } else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") || && (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@ -1017,10 +974,10 @@ public:
void check_errors() { void check_errors() {
if (!_errors.empty()) { if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
} }
if (!_warnings.empty()) { if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
} }
} }

View file

@ -5,9 +5,6 @@
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
struct llama_grammar_builder { struct llama_grammar_builder {

View file

@ -631,7 +631,7 @@ common_tool_call_handler common_tool_call_handler_init(
handler.grammar_triggers.push_back("{\n \""); handler.grammar_triggers.push_back("{\n \"");
} }
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); builder.add_rule("root", string_join(tool_rules, " | "));
}); });
handler.additional_stops.push_back("<|eom_id|>"); handler.additional_stops.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, { handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
@ -658,9 +658,9 @@ common_tool_call_handler common_tool_call_handler_init(
handler.grammar_triggers.push_back("\n>>>" + name + "\n"); handler.grammar_triggers.push_back("\n>>>" + name + "\n");
} }
} }
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (parallel) { if (parallel) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else { } else {
builder.add_rule("root", first_rule); builder.add_rule("root", first_rule);
@ -690,7 +690,7 @@ common_tool_call_handler common_tool_call_handler_init(
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space")); tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
} }
} }
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call); builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back("<function="); handler.grammar_triggers.push_back("<function=");
@ -721,7 +721,7 @@ common_tool_call_handler common_tool_call_handler_init(
})); }));
} }
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space"; auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call); builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back("<tool_call>"); handler.grammar_triggers.push_back("<tool_call>");