From aeefbbb4a317810922b86ee818cf67601fe68c08 Mon Sep 17 00:00:00 2001 From: Yingbei Date: Tue, 12 Mar 2024 15:08:16 -0700 Subject: [PATCH] modify llama2 template formatting logic --- examples/server/utils.hpp | 7 ++ llama.cpp | 8 +- test_llamacpp.ipynb | 182 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 193 insertions(+), 4 deletions(-) create mode 100644 test_llamacpp.ipynb diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 38a9d3eaf..1c10063ee 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -523,6 +523,12 @@ static json oaicompat_completion_params_parse( return llama_params; } + +static json parse_response_for_function_call(const std::string content) { + +} + + static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) { bool stopped_word = result.count("stopped_word") != 0; bool stopped_eos = json_value(result, "stopped_eos", false); @@ -559,6 +565,7 @@ static json format_final_response_oaicompat(const json & request, json result, c }}, {"id", completion_id} }; + printf("format_final_response_oaicompat: %s\n", res.dump().c_str()); if (server_verbose) { res["__verbose"] = result; diff --git a/llama.cpp b/llama.cpp index ad7b7b7d4..02ca4832d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -14059,7 +14059,7 @@ static int32_t llama_chat_apply_template_internal( bool strip_message = tmpl.find("content.strip()") != std::string::npos; // construct the prompt bool is_inside_turn = true; // skip BOS at the beginning - ss << "[INST] "; + // ss << "[INST] "; for (auto message : chat) { std::string content = strip_message ? 
trim(message->content) : message->content; std::string role(message->role); @@ -14072,10 +14072,10 @@ static int32_t llama_chat_apply_template_internal( ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n"; } else { // if the model does not support system message, we still include it in the first message, but without <<SYS>> - ss << content << "\n"; + ss << "" << content << "\n"; } - } else if (role == "user") { - ss << content << " [/INST]"; + } else if (role == "user" or role == "observation") { + ss << "[INST]" << content << " [/INST]"; } else { ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>"; is_inside_turn = false; diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb new file mode 100644 index 000000000..c33b9ea79 --- /dev/null +++ b/test_llamacpp.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "\n", + "def get_gorilla_response(prompt, model, functions, msgs):\n", + " openai.api_key = \"abc\"\n", + " openai.base_url = \"http://localhost:8019/v1/\"\n", + " \n", + " try:\n", + " completion = openai.chat.completions.create(\n", + " model=model,\n", + " temperature=0.1,\n", + " messages=msgs,\n", + " functions=functions,\n", + " function_call=\"auto\",\n", + " stream=False,\n", + " )\n", + " return completion.choices[0]\n", + " except Exception as e:\n", + " print(e, model, prompt)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. 
Could you please tell me the brand name you prefer?\n" + ] + } + ], + "source": [ + "system_prompt = \"You are a helpful assistant.\"\n", + "functions = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"getCurrentWeather\",\n", + " \"description\": \"Get the weather in location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n", + " \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n", + " },\n", + " \"required\": [\"location\"]\n", + " }\n", + " }\n", + " },\n", + " { \"type\": \"function\",\n", + " \"function\":\n", + " {\n", + " \"name\": \"orderUmbrella\",\n", + " \"description\": \"Do this to help user to order an umbrella online\", \n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"brand_name\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The name of the umbrella brand\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"brand_name\"\n", + " ]\n", + " }\n", + " }},\n", + "]\n", + "\n", + "\n", + "user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>The current weather in Boston is 72 degrees f and rainy.\"}\n", + " ,\n", + " # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. 
It might be a good idea to carry an umbrella with you.\"},\n", + " # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n", + " ]\n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", + "\n", + "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", + "print(res.message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " <>[get_stock_price(symbol=\"TSLA\")]\n", + "<>[get_stock_price(symbol=\"GOOG\")]\n" + ] + } + ], + "source": [ + "system_prompt = \"You are a helpful assistant.\"\n", + "functions = [\n", + " {\"function\":\n", + " {\n", + " \"name\": \"get_stock_price\",\n", + " \"description\": \"Get the current stock price\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"symbol\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The stock symbol, e.g. 
AAPL, GOOG\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"symbol\"\n", + " ]\n", + " }\n", + " }},\n", + " {\"function\":{\n", + " \"name\": \"check_word_anagram\",\n", + " \"description\": \"Check if two words are anagrams of each other\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"word1\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The first word\"\n", + " },\n", + " \"word2\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The second word\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"word1\",\n", + " \"word2\"\n", + " ]\n", + " }\n", + " }}\n", + "]\n", + "\n", + "user_query = \"What's the stock price of Tesla and Google?\"\n", + "\n", + "# \n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", + "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", + "print(res.message.content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}