update notebook for easier comparison

This commit is contained in:
Yingbei 2024-03-26 15:11:49 -07:00
parent eaec0b8748
commit 711fda99c6
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD

View file

@ -2,14 +2,69 @@
"cells": [
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def run_completion(chat_method, user_query):\n",
" print(chat_method)\n",
"import json\n",
"import uuid\n",
"from functools import partial\n",
"\n",
"def get_oai_response(model, functions, msgs, api_key, base_url):\n",
" import openai\n",
" openai.api_key = api_key ## Add your API key here\n",
" openai.base_url = base_url\n",
" print(f\"Pointing to URL: {base_url}\")\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" tools=functions,\n",
" tool_choice=\"auto\",\n",
" # functions=functions,\n",
" # function_call=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
" except Exception as e:\n",
" print(e)\n",
"\n",
"\n",
"def insert_tool_response(res, msgs):\n",
" # for tool_call in res.message.tool_calls:\n",
" # print(f\"Tool Call: {tool_call.id}, {tool_call.function}\")\n",
" assistant_message = res.message\n",
" tool_calls = []\n",
" for tool_call in assistant_message.tool_calls:\n",
" tool_calls.append( {\n",
" \"id\": tool_call.id,\n",
" \"function\": {\"name\": tool_call.function.name,\n",
" \"arguments\": tool_call.function.arguments},\n",
" \"type\": \"function\",\n",
" })\n",
" msgs.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n",
" \n",
" for i, tool_call in enumerate(assistant_message.tool_calls):\n",
" if tool_call.function.name == \"getCurrentWeather\":\n",
" print()\n",
" l = len((json.loads(assistant_message.tool_calls[i].function.arguments))[\"location\"])\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {(i+1) * 50 + l } degree\"})\n",
" elif tool_call.function.name == \"calculate_distance\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Distance is {(i+1) * 1700} miles.\"})\n",
" elif tool_call.function.name == \"generate_password\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Password generated: {uuid.uuid4().hex[:8]}\"})\n",
" else:\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed. the price is {(i+1) * 10} dollars.\"})\n",
" \n",
" return msgs\n",
"\n",
"def run_completion(chat_method, user_query, msgs=[]):\n",
" system_prompt = \"You are a helpful assistant.\"\n",
" functions = [\n",
" # {\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}},\n",
" \n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
@ -43,102 +98,36 @@
" ]\n",
" }\n",
" }},\n",
" ]\n",
" {\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}}\n",
" ]\n",
" # functions = [{\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}}]\n",
"\n",
" msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
" if not msgs or len(msgs) == 0:\n",
" msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
" else:\n",
" msgs.append({\"role\": \"user\", \"content\": user_query})\n",
"\n",
" res = chat_method(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
" print(f\"First Response:\")\n",
" for tool_call in res.message.tool_calls:\n",
" print(f\"Tool Call: {tool_call.id}, {tool_call.function}\")\n",
" assistant_message = res.message\n",
" tool_calls = []\n",
" for tool_call in assistant_message.tool_calls:\n",
" tool_calls.append( {\n",
" \"id\": tool_call.id,\n",
" \"function\": {\"name\": tool_call.function.name,\n",
" \"arguments\": tool_call.function.arguments},\n",
" \"type\": \"function\",\n",
" })\n",
" msgs.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n",
" \n",
" for i, tool_call in enumerate(assistant_message.tool_calls):\n",
" if tool_call.function.name == \"getCurrentWeather\":\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {i * 50} degree\"})\n",
" res = chat_method(model=\"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
" res_next = res\n",
" if res_next.message.content and len(res_next.message.content) > 0:\n",
" print(\"\\n[AI response]:\\n\", res_next.message.content)\n",
" else:\n",
" print(\"\\n[AI calling functions]:\")\n",
" for tool_call in res_next.message.tool_calls:\n",
" print(f\"Tool Call: {tool_call.function}\")\n",
" while res_next.message.tool_calls and len(res_next.message.tool_calls) > 0:\n",
" msgs = insert_tool_response(res_next, msgs)\n",
"\n",
" res_next = chat_method(model=\"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
" # for m in msgs:\n",
" # print(m)\n",
" if res_next.message.content and len(res_next.message.content) > 0:\n",
" print(\"\\n[AI response]:\\n\", res_next.message.content)\n",
" else:\n",
" msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed.\"})\n",
" \n",
"\n",
" print(\"Print before second response...\")\n",
" res_next = chat_method(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
" for m in msgs:\n",
" print(m)\n",
" print(f\"Second Response: {res_next.message}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"\n",
"def get_oai_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"sk-\" ## Add your API key here\n",
" openai.base_url = \"https://api.openai.com/v1/\"\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" tools=functions,\n",
" tool_choice=\"auto\",\n",
" # functions=functions,\n",
" # function_call=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
" except Exception as e:\n",
" print(e, model, prompt)\n"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function get_oai_response at 0x10b9c93f0>\n",
"First Response:\n",
"Tool Call: call_FYLEpX5CVo2dqSyNupcgtFak, Function(arguments='{\"number_to_buy\":2}', name='orderUmbrella')\n",
"Print before second response...\n",
"{'role': 'system', 'content': 'You are a helpful assistant.'}\n",
"{'role': 'user', 'content': 'order 2 umbrellas'}\n",
"{'role': 'assistant', 'tool_calls': [{'id': 'call_FYLEpX5CVo2dqSyNupcgtFak', 'function': {'name': 'orderUmbrella', 'arguments': '{\"number_to_buy\":2}'}, 'type': 'function'}]}\n",
"{'role': 'tool', 'tool_call_id': 'call_FYLEpX5CVo2dqSyNupcgtFak', 'name': 'orderUmbrella', 'content': 'Order placed.'}\n",
"Second Response: ChatCompletionMessage(content=\"I've placed the order for 2 umbrellas for you. Is there anything else I can help with?\", role='assistant', function_call=None, tool_calls=None)\n"
]
}
],
"source": [
"# user_query = \"What is the distance between San Francisco and Cupertino by car and by air\"\n",
"# user_query = \"weather in boston as well as cupertino?\"\n",
"user_query = \"order 2 umbrellas\"\n",
"run_completion(get_oai_response, user_query)"
" print(\"\\n[AI calling functions]:\")\n",
" for tool_call in res_next.message.tool_calls:\n",
" print(f\"Tool Call: {tool_call.function}\")\n",
" "
]
},
{
@ -150,58 +139,102 @@
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"\n",
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"sk-\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" tools=functions,\n",
" tool_choice=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
" except Exception as e:\n",
" print(e, model, prompt)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function get_mistral_rubra_response at 0x10b9cac20>\n",
"First Response:\n",
"Tool Call: 0, Function(arguments='{\"number_to_buy\":\"2\"}', name='orderUmbrella')\n",
"Print before second response...\n",
"{'role': 'system', 'content': 'You are a helpful assistant.'}\n",
"{'role': 'user', 'content': 'order 2 umbrellas'}\n",
"{'role': 'assistant', 'tool_calls': [{'id': '0', 'function': {'name': 'orderUmbrella', 'arguments': '{\"number_to_buy\":\"2\"}'}, 'type': 'function'}]}\n",
"{'role': 'tool', 'tool_call_id': '0', 'name': 'orderUmbrella', 'content': 'Order placed.'}\n",
"Second Response: ChatCompletionMessage(content=' Your order for 2 umbrellas has been placed.', role='assistant', function_call=None, tool_calls=None)\n"
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Boston, MA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Cupertino, CA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Los Angeles, CA\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"New York City, NY\"}', name='getCurrentWeather')\n",
"\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI response]:\n",
" The distance from Boston, MA to Cupertino, CA is approximately 2,800 miles. The distance from Los Angeles, CA to New York City, NY is approximately 2,800 miles as well.\n"
]
}
],
"source": [
"# user_query = \"generate a password of length 10 and another of length 20\" \n",
"# user_query = \"what's the weather in Boston and Cupertino?\"\n",
"user_query = \"order 2 umbrellas\"\n",
"run_completion(get_mistral_rubra_response, user_query)"
"import openai\n",
"\n",
"\n",
"local_api_key = \"sk-\"\n",
"local_base_url = \"http://localhost:8019/v1/\"\n",
"get_mistral_rubra_response = partial(get_oai_response, api_key=local_api_key, base_url=local_base_url)\n",
"\n",
"user_query = \"calculate the distance from boston to cupertino? and distance from LA to NYC\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query)\n",
"# user_query = \"what's the weather in Boston and Cupertino and Chicago?\"\n",
"# # user_query = \"order 2 umbrellas\"\n",
"# msgs = run_completion(get_mistral_rubra_response, user_query)\n",
"# user_query2 = \"now order 3 umbrellas for me\"\n",
"# msgs = run_completion(get_mistral_rubra_response, user_query2, msgs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pointing to URL: https://api.openai.com/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\": \"Boston, MA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
"Tool Call: Function(arguments='{\"location\": \"Cupertino, CA\", \"unit\": \"f\"}', name='getCurrentWeather')\n",
"Tool Call: Function(arguments='{\"origin\": \"Boston, MA\", \"destination\": \"Cupertino, CA\", \"mode\": \"driving\"}', name='calculate_distance')\n",
"\n",
"\n",
"Pointing to URL: https://api.openai.com/v1/\n",
"\n",
"[AI response]:\n",
" The current weather in Boston, MA is 60°F, and in Cupertino, CA, it is 113°F. The distance from Boston, MA to Cupertino, CA is approximately 5100 miles when traveling by driving.\n"
]
}
],
"source": [
"import openai\n",
"\n",
"\n",
"\n",
"\n",
"openai_api_key = \"sk-\"\n",
"openai_base_url = \"https://api.openai.com/v1/\"\n",
"get_openai_response = partial(get_oai_response, api_key=openai_api_key, base_url=openai_base_url)\n",
"\n",
"# oai_user_query = \"What is the distance between San Francisco and Cupertino by car and by air\"\n",
"oai_user_query = \"weather in boston as well as cupertino? and calculate the distance from boston to cupertino\"\n",
"# user_query = \"order 2 umbrellas\"\n",
"run_completion(get_openai_response, oai_user_query)"
]
},
{
@ -213,14 +246,18 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n"
"ename": "TypeError",
"evalue": "get_oai_response() got multiple values for argument 'functions'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 48\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;66;03m# \u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"{'symbol': 'TSLA', 'company_name': 'Tesla, Inc.', 'sector': 'Consumer Cyclical', 'industry': 'Auto Manufacturers', 'market_cap': 611384164352, 'pe_ratio': 49.604652, 'pb_ratio': 9.762013, 'dividend_yield': None, 'eps': 4.3, 'beta': 2.427, '52_week_high': 299.29, '52_week_low': 152.37}}\"}]\u001b[39;00m\n\u001b[1;32m 47\u001b[0m msgs \u001b[38;5;241m=\u001b[39m [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msystem\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m:system_prompt} ,{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: user_query},]\n\u001b[0;32m---> 48\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mget_mistral_rubra_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43muser_query\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmistral_rubra\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunctions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunctions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsgs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmsgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28mprint\u001b[39m(res\u001b[38;5;241m.\u001b[39mmessage\u001b[38;5;241m.\u001b[39mcontent)\n",
"\u001b[0;31mTypeError\u001b[0m: get_oai_response() got multiple values for argument 'functions'"
]
}
],
@ -278,7 +315,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": null,
"metadata": {},
"outputs": [
{