modify llama2 template formatting logic

Yingbei 2024-03-12 15:08:16 -07:00
parent 8902fd41f0
commit aeefbbb4a3
3 changed files with 193 additions and 4 deletions


@@ -523,6 +523,12 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
+static json parse_response_for_function_call(const std::string content) {
+}
 static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
@@ -559,6 +565,7 @@ static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false)
         }},
         {"id", completion_id}
     };
 
+    printf("format_final_response_oaicompat: %s\n", res.dump().c_str());
     if (server_verbose) {
         res["__verbose"] = result;

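The new parse_response_for_function_call helper is committed as an empty stub. A minimal sketch of what such a parser might do, assuming the model emits calls in the <<functions>>[...] format exercised by the notebook below; the regex and the returned JSON shape are assumptions, not the committed behavior:

#include <regex>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Hypothetical sketch only: collect every "<<functions>>[...]" span from the
// completion text into a JSON array. The marker format and output shape are
// assumptions; the commit leaves this function unimplemented.
static json parse_response_for_function_call(const std::string content) {
    json calls = json::array();
    static const std::regex re(R"(<<functions>>\[(.*?)\])");
    for (std::sregex_iterator it(content.begin(), content.end(), re), end; it != end; ++it) {
        // (*it)[1] is the text between the brackets, e.g. get_stock_price(symbol="TSLA")
        calls.push_back({{"function_call", (*it)[1].str()}});
    }
    return calls;
}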

@@ -14059,7 +14059,7 @@ static int32_t llama_chat_apply_template_internal(
         bool strip_message = tmpl.find("content.strip()") != std::string::npos;
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
-        ss << "[INST] ";
+        // ss << "[INST] ";
         for (auto message : chat) {
             std::string content = strip_message ? trim(message->content) : message->content;
             std::string role(message->role);
@@ -14072,10 +14072,10 @@ static int32_t llama_chat_apply_template_internal(
                     ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
                 } else {
                     // if the model does not support system message, we still include it in the first message, but without <<SYS>>
-                    ss << content << "\n";
+                    ss << "<s>" << content << "\n";
                 }
-            } else if (role == "user") {
-                ss << content << " [/INST]";
+            } else if (role == "user" or role == "observation") {
+                ss << "[INST]" << content << " [/INST]";
             } else {
                 ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
                 is_inside_turn = false;

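Net effect of the template change: the unconditional "[INST] " prefix is gone, every user (or observation) turn now carries its own [INST] ... [/INST] wrapper, and when the template lacks <<SYS>> support the system content is prefixed with <s> instead. A standalone sketch that mirrors the modified loop, simplified (strip_message/trim and space_around_response handling omitted, messages illustrative), to show the prompt it assembles:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

int main() {
    // illustrative chat; "observation" is the role the patch starts treating like "user"
    std::vector<msg> chat = {
        {"system",      "You are a helpful assistant."},
        {"user",        "What's the weather like in Boston?"},
        {"assistant",   "It is 72 degrees f and rainy."},
        {"observation", "<<observation>>72 degrees f and rainy."},
    };
    std::ostringstream ss;
    for (const auto & m : chat) {
        if (m.role == "system") {
            // branch for templates that support <<SYS>>; without support the
            // patch emits "<s>" << content << "\n" instead
            ss << "<<SYS>>\n" << m.content << "\n<</SYS>>\n\n";
        } else if (m.role == "user" || m.role == "observation") {
            // the patch wraps each of these turns in its own [INST] ... [/INST]
            ss << "[INST]" << m.content << " [/INST]";
        } else {
            // assistant-style turns, with the surrounding spaces the real code
            // adds when space_around_response is set
            ss << " " << m.content << " " << "</s>";
        }
    }
    std::cout << ss.str() << std::endl;
    return 0;
}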
test_llamacpp.ipynb (new file, 182 lines)

@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"\n",
"def get_gorilla_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"abc\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" functions=functions,\n",
" function_call=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
" except Exception as e:\n",
" print(e, model, prompt)\n"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. Could you please tell me the brand name you prefer?\n"
]
}
],
"source": [
"system_prompt = \"You are a helpful assistant.\"\n",
"functions = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"getCurrentWeather\",\n",
" \"description\": \"Get the weather in location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" }\n",
" },\n",
" { \"type\": \"function\",\n",
" \"function\":\n",
" {\n",
" \"name\": \"orderUmbrella\",\n",
" \"description\": \"Do this to help user to order an umbrella online\", \n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"brand_name\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The name of the umbrella brand\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"brand_name\"\n",
" ]\n",
" }\n",
" }},\n",
"]\n",
"\n",
"\n",
"user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current weather in Boston is 72 degrees f and rainy.\"}\n",
" ,\n",
" # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. It might be a good idea to carry an umbrella with you.\"},\n",
" # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n",
" ]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
"\n",
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"print(res.message.content)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" <<functions>>[get_stock_price(symbol=\"TSLA\")]\n",
"<<functions>>[get_stock_price(symbol=\"GOOG\")]\n"
]
}
],
"source": [
"system_prompt = \"You are a helpful assistant.\"\n",
"functions = [\n",
" {\"function\":\n",
" {\n",
" \"name\": \"get_stock_price\",\n",
" \"description\": \"Get the current stock price\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"symbol\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The stock symbol, e.g. AAPL, GOOG\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"symbol\"\n",
" ]\n",
" }\n",
" }},\n",
" {\"function\":{\n",
" \"name\": \"check_word_anagram\",\n",
" \"description\": \"Check if two words are anagrams of each other\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"word1\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The first word\"\n",
" },\n",
" \"word2\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The second word\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"word1\",\n",
" \"word2\"\n",
" ]\n",
" }\n",
" }}\n",
"]\n",
"\n",
"user_query = \"What's the stock price of Tesla and Google?\"\n",
"\n",
"# \n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"print(res.message.content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}