add tool-choice parameter
This commit is contained in:
parent
352f79ca9f
commit
becf9b4003
4 changed files with 103 additions and 31 deletions
|
@ -1993,15 +1993,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
std::back_inserter(params.chat_template));
|
std::back_inserter(params.chat_template));
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
||||||
|
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--tools"}, "JINJA_TOOLS",
|
{"--tools"}, "JINJA_TOOLS",
|
||||||
string_format(
|
"set to JSON array of tool definitions used for assistant function-calling (requires --jinja)",
|
||||||
"set to a JSON array of tool definitions used for assistant function-calling "
|
|
||||||
"(requires --jinja)"),
|
|
||||||
[](common_params ¶ms, const std::string & value) {
|
[](common_params ¶ms, const std::string & value) {
|
||||||
params.jinja_tools = value;
|
params.jinja_tools.tools(value);
|
||||||
}
|
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS"));
|
}).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||||
|
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--tool-choice"}, "JINJA_TOOL_CHOICE",
|
||||||
|
"set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")",
|
||||||
|
[](common_params ¶ms, const std::string & value) {
|
||||||
|
params.jinja_tools.choice(value);
|
||||||
|
|
||||||
|
}).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||||
|
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
|
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
|
||||||
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
|
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
|
||||||
|
|
|
@ -7,9 +7,6 @@
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
|
||||||
#include "json.hpp"
|
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "chat.hpp"
|
#include "chat.hpp"
|
||||||
|
@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||||
// Chat template utils
|
// Chat template utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
common_params_tools::common_params_tools(std::string tools, std::string choice) {
|
||||||
|
this->tools(tools);
|
||||||
|
this->choice(choice);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_params_tools::tools(std::string tools) {
|
||||||
|
try {
|
||||||
|
tools_ = std::make_shared<json>(json::parse(tools));
|
||||||
|
if (! tools_->is_array()) {
|
||||||
|
throw std::invalid_argument("tools must be a valid JSON array");
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const json::exception & err) {
|
||||||
|
throw std::invalid_argument(err.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_params_tools::choice(std::string choice) {
|
||||||
|
try {
|
||||||
|
if (choice == "auto" || choice == "required" || choice == "none") {
|
||||||
|
tool_choice_ = std::move(choice);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
auto choice_ptr = std::make_shared<json>(json::parse(choice));
|
||||||
|
tool_choice_ = choice_ptr;
|
||||||
|
if (! choice_ptr->is_object()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const json::exception & err) {
|
||||||
|
throw std::invalid_argument(err.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
try {
|
try {
|
||||||
|
@ -1798,7 +1831,7 @@ std::string common_chat_apply_template(
|
||||||
const std::vector<common_chat_msg> & msgs,
|
const std::vector<common_chat_msg> & msgs,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
std::string tools_json_arr)
|
const common_params_tools & tools)
|
||||||
{
|
{
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
common_chat_inputs inputs;
|
common_chat_inputs inputs;
|
||||||
|
@ -1807,17 +1840,19 @@ std::string common_chat_apply_template(
|
||||||
for (const auto & msg : msgs) {
|
for (const auto & msg : msgs) {
|
||||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||||
}
|
}
|
||||||
|
if (tools.tools() != nullptr) {
|
||||||
|
inputs.tools = *tools.tools();
|
||||||
|
}
|
||||||
|
auto choice = tools.choice();
|
||||||
|
if (std::holds_alternative<std::string>(choice)) {
|
||||||
|
inputs.tool_choice = std::get<std::string>(choice);
|
||||||
|
|
||||||
if (! tools_json_arr.empty()) {
|
} else {
|
||||||
try {
|
auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
|
||||||
inputs.tools = tools_json_arr;
|
if (choice_ptr != nullptr) {
|
||||||
|
inputs.tool_choice = *choice_ptr;
|
||||||
} catch (const json::exception & err) {
|
|
||||||
LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n",
|
|
||||||
tools_json_arr.c_str(), err.what());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs.messages = messages;
|
inputs.messages = messages;
|
||||||
inputs.add_generation_prompt = add_ass;
|
inputs.add_generation_prompt = add_ass;
|
||||||
return common_chat_params_init(tmpl, inputs).prompt;
|
return common_chat_params_init(tmpl, inputs).prompt;
|
||||||
|
@ -1858,11 +1893,11 @@ std::string common_chat_format_single(
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
std::string tools_json_arr)
|
const common_params_tools & tools)
|
||||||
{
|
{
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
auto fmt_past_msg = past_msg.empty() ? ""
|
auto fmt_past_msg = past_msg.empty() ? ""
|
||||||
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr);
|
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
|
||||||
|
|
||||||
std::vector<common_chat_msg> chat_new(past_msg);
|
std::vector<common_chat_msg> chat_new(past_msg);
|
||||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||||
|
@ -1871,7 +1906,7 @@ std::string common_chat_format_single(
|
||||||
};
|
};
|
||||||
// format chat with new_msg
|
// format chat with new_msg
|
||||||
chat_new.push_back(new_msg);
|
chat_new.push_back(new_msg);
|
||||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr);
|
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
|
||||||
// get the diff part
|
// get the diff part
|
||||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||||
return ss.str();
|
return ss.str();
|
||||||
|
|
|
@ -8,6 +8,9 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||||
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
|
#include "json.hpp"
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#define DIRECTORY_SEPARATOR '\\'
|
#define DIRECTORY_SEPARATOR '\\'
|
||||||
|
@ -202,6 +205,31 @@ struct common_params_vocoder {
|
||||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class common_params_tools {
|
||||||
|
public:
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
using json_ptr = std::shared_ptr<json>;
|
||||||
|
using tool_choice_t = std::variant<std::string, json_ptr>;
|
||||||
|
|
||||||
|
common_params_tools(std::string tools = "",
|
||||||
|
std::string choice = "auto");
|
||||||
|
|
||||||
|
common_params_tools(const common_params_tools & other) = default;
|
||||||
|
common_params_tools(common_params_tools && other) noexcept = default;
|
||||||
|
common_params_tools & operator=(const common_params_tools & other) = default;
|
||||||
|
common_params_tools & operator=(common_params_tools && other) noexcept = default;
|
||||||
|
|
||||||
|
void tools(std::string tools);
|
||||||
|
const json * tools() const { return tools_.get(); }
|
||||||
|
|
||||||
|
void choice(std::string choice);
|
||||||
|
const tool_choice_t & choice() const { return tool_choice_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
json_ptr tools_;
|
||||||
|
tool_choice_t tool_choice_;
|
||||||
|
};
|
||||||
|
|
||||||
struct common_params {
|
struct common_params {
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
int32_t n_ctx = 4096; // context size
|
int32_t n_ctx = 4096; // context size
|
||||||
|
@ -346,7 +374,8 @@ struct common_params {
|
||||||
std::string chat_template = ""; // NOLINT
|
std::string chat_template = ""; // NOLINT
|
||||||
bool use_jinja = false; // NOLINT
|
bool use_jinja = false; // NOLINT
|
||||||
bool enable_chat_template = true;
|
bool enable_chat_template = true;
|
||||||
std::string jinja_tools = "";
|
common_params_tools jinja_tools;
|
||||||
|
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
|
|
||||||
std::string ssl_file_key = ""; // NOLINT
|
std::string ssl_file_key = ""; // NOLINT
|
||||||
|
@ -649,7 +678,7 @@ std::string common_chat_apply_template(
|
||||||
const std::vector<common_chat_msg> & chat,
|
const std::vector<common_chat_msg> & chat,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
std::string tools_json_arr = std::string());
|
const common_params_tools & tools = common_params_tools());
|
||||||
|
|
||||||
// Format single message, while taking into account the position of that message in chat history
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
std::string common_chat_format_single(
|
std::string common_chat_format_single(
|
||||||
|
@ -658,7 +687,7 @@ std::string common_chat_format_single(
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
std::string tools_json_arr = std::string());
|
const common_params_tools & tools = common_params_tools());
|
||||||
|
|
||||||
// Returns an example of formatted chat
|
// Returns an example of formatted chat
|
||||||
std::string common_chat_format_example(
|
std::string common_chat_format_example(
|
||||||
|
|
|
@ -263,14 +263,14 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<llama_token> embd_inp;
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role,
|
auto chat_add_and_format = [&chat_msgs, &chat_templates](
|
||||||
const std::string & content,
|
const std::string & role, const std::string & content,
|
||||||
const std::string & tools = std::string()) {
|
const common_params_tools & tools = common_params_tools())
|
||||||
|
{
|
||||||
common_chat_msg new_msg{role, content, {}};
|
common_chat_msg new_msg{role, content, {}};
|
||||||
auto formatted = common_chat_format_single(*chat_templates.template_default,
|
auto formatted = common_chat_format_single(
|
||||||
chat_msgs, new_msg,
|
*chat_templates.template_default, chat_msgs, new_msg, role == "user",
|
||||||
role == "user",
|
g_params->use_jinja, tools);
|
||||||
g_params->use_jinja, tools);
|
|
||||||
|
|
||||||
chat_msgs.push_back({role, content, {}});
|
chat_msgs.push_back({role, content, {}});
|
||||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue