diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 38a9d3eaf..1c10063ee 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -523,6 +523,12 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
+
+// Parse raw model output into an OpenAI-style "function_call" JSON
+// object; an empty object means "no function call detected".
+static json parse_response_for_function_call(const std::string & content) {
+    return json::object(); // TODO: implement actual parsing of `content`.
+}
static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
@@ -559,6 +565,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
}},
{"id", completion_id}
};
+    if (server_verbose) { printf("format_final_response_oaicompat: %s\n", res.dump().c_str()); }
if (server_verbose) {
res["__verbose"] = result;
diff --git a/llama.cpp b/llama.cpp
index ad7b7b7d4..02ca4832d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14059,7 +14059,7 @@ static int32_t llama_chat_apply_template_internal(
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
- ss << "[INST] ";
+ // ss << "[INST] ";
for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
@@ -14072,10 +14072,10 @@ static int32_t llama_chat_apply_template_internal(
ss << "<>\n" << content << "\n<>\n\n";
} else {
// if the model does not support system message, we still include it in the first message, but without <>
- ss << content << "\n";
+ ss << "" << content << "\n";
}
- } else if (role == "user") {
- ss << content << " [/INST]";
+    } else if (role == "user" || role == "observation") {
+        ss << "[INST] " << content << " [/INST]";
} else {
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "";
is_inside_turn = false;
diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb
new file mode 100644
index 000000000..c33b9ea79
--- /dev/null
+++ b/test_llamacpp.ipynb
@@ -0,0 +1,182 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "\n",
+ "\n",
+ "def get_gorilla_response(prompt, model, functions, msgs):\n",
+ " openai.api_key = \"abc\"\n",
+ " openai.base_url = \"http://localhost:8019/v1/\"\n",
+ " \n",
+ " try:\n",
+ " completion = openai.chat.completions.create(\n",
+ " model=model,\n",
+ " temperature=0.1,\n",
+ " messages=msgs,\n",
+ " functions=functions,\n",
+ " function_call=\"auto\",\n",
+ " stream=False,\n",
+ " )\n",
+ " return completion.choices[0]\n",
+ " except Exception as e:\n",
+ " print(e, model, prompt)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 136,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. Could you please tell me the brand name you prefer?\n"
+ ]
+ }
+ ],
+ "source": [
+ "system_prompt = \"You are a helpful assistant.\"\n",
+ "functions = [\n",
+ " {\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\n",
+ " \"name\": \"getCurrentWeather\",\n",
+ " \"description\": \"Get the weather in location\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n",
+ " \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n",
+ " },\n",
+ " \"required\": [\"location\"]\n",
+ " }\n",
+ " }\n",
+ " },\n",
+ " { \"type\": \"function\",\n",
+ " \"function\":\n",
+ " {\n",
+ " \"name\": \"orderUmbrella\",\n",
+ " \"description\": \"Do this to help user to order an umbrella online\", \n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"brand_name\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The name of the umbrella brand\"\n",
+ " }\n",
+ " },\n",
+ " \"required\": [\n",
+ " \"brand_name\"\n",
+ " ]\n",
+ " }\n",
+ " }},\n",
+ "]\n",
+ "\n",
+ "\n",
+ "user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n",
+ "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>The current weather in Boston is 72 degrees f and rainy.\"}\n",
+ " ,\n",
+ " # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. It might be a good idea to carry an umbrella with you.\"},\n",
+ " # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n",
+ " ]\n",
+ "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
+ "\n",
+ "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
+ "print(res.message.content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 134,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " <>[get_stock_price(symbol=\"TSLA\")]\n",
+ "<>[get_stock_price(symbol=\"GOOG\")]\n"
+ ]
+ }
+ ],
+ "source": [
+ "system_prompt = \"You are a helpful assistant.\"\n",
+ "functions = [\n",
+ " {\"function\":\n",
+ " {\n",
+ " \"name\": \"get_stock_price\",\n",
+ " \"description\": \"Get the current stock price\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"symbol\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The stock symbol, e.g. AAPL, GOOG\"\n",
+ " }\n",
+ " },\n",
+ " \"required\": [\n",
+ " \"symbol\"\n",
+ " ]\n",
+ " }\n",
+ " }},\n",
+ " {\"function\":{\n",
+ " \"name\": \"check_word_anagram\",\n",
+ " \"description\": \"Check if two words are anagrams of each other\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"word1\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The first word\"\n",
+ " },\n",
+ " \"word2\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The second word\"\n",
+ " }\n",
+ " },\n",
+ " \"required\": [\n",
+ " \"word1\",\n",
+ " \"word2\"\n",
+ " ]\n",
+ " }\n",
+ " }}\n",
+ "]\n",
+ "\n",
+ "user_query = \"What's the stock price of Tesla and Google?\"\n",
+ "\n",
+ "# \n",
+ "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n",
+ "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
+ "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
+ "print(res.message.content)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "py310",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}