{
"cells": [
{
"cell_type": "markdown",
"id": "ddfa9ae6-69fe-444a-b994-8c4c5970a7ec",
"metadata": {},
"source": [
"# Project - Airline AI Assistant\n",
"\n",
"We'll now bring together what we've learned to make an AI Customer Support assistant for an Airline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import json\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae",
"metadata": {},
"outputs": [],
"source": [
"# Initialization\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"MODEL = \"gpt-4o-mini\"\n",
"openai = OpenAI()\n",
"\n",
"# As an alternative, if you'd like to use Ollama instead of OpenAI\n",
"# Check that Ollama is running for you locally (see week1/day2 exercise) then uncomment these next 2 lines\n",
"# MODEL = \"llama3.2\"\n",
"# openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a521d84-d07c-49ab-a0df-d6451499ed97",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
"system_message += \"Always be accurate. If you don't know the answer, say so.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61a2a15d-b559-4844-b377-6bd5cb4949f6",
"metadata": {},
"outputs": [],
"source": [
"# This function looks rather simpler than the one from my video, because we're taking advantage of the latest Gradio updates\n",
"\n",
"def chat(message, history):\n",
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
" return response.choices[0].message.content\n",
"\n",
"gr.ChatInterface(fn=chat, type=\"messages\").launch()"
]
},
{
"cell_type": "markdown",
"id": "36bedabf-a0a7-4985-ad8e-07ed6a55a3a4",
"metadata": {},
"source": [
"## Tools\n",
"\n",
"Tools are an incredibly powerful feature provided by the frontier LLMs.\n",
"\n",
"With tools, you can write a function, and have the LLM call that function as part of its response.\n",
"\n",
"Sounds almost spooky.. we're giving it the power to run code on our machine?\n",
"\n",
"Well, kinda."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0696acb1-0b05-4dc2-80d5-771be04f1fb2",
"metadata": {},
"outputs": [],
"source": [
"# Let's start by making a useful function\n",
"\n",
"ticket_prices = {\"london\": \"$799\", \"paris\": \"$899\", \"tokyo\": \"$1400\", \"berlin\": \"$499\"}\n",
"\n",
"def get_ticket_price(destination_city):\n",
" print(f\"Tool get_ticket_price called for {destination_city}\")\n",
" city = destination_city.lower()\n",
" return ticket_prices.get(city, \"Unknown\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80ca4e09-6287-4d3f-997d-fa6afbcf6c85",
"metadata": {},
"outputs": [],
"source": [
"get_ticket_price(\"Berlin\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4afceded-7178-4c05-8fa6-9f2085e6a344",
"metadata": {},
"outputs": [],
"source": [
"# There's a particular dictionary structure that's required to describe our function:\n",
"\n",
"price_function = {\n",
" \"name\": \"get_ticket_price\",\n",
" \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"destination_city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city that the customer wants to travel to\",\n",
" },\n",
" },\n",
" \"required\": [\"destination_city\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdca8679-935f-4e7f-97e6-e71a4d4f228c",
"metadata": {},
"outputs": [],
"source": [
"# And this is included in a list of tools:\n",
"\n",
"tools = [{\"type\": \"function\", \"function\": price_function}]"
]
},
{
"cell_type": "markdown",
"id": "c3d3554f-b4e3-4ce7-af6f-68faa6dd2340",
"metadata": {},
"source": [
"## Getting OpenAI to use our Tool\n",
"\n",
"There's some fiddly stuff to allow OpenAI \"to call our tool\"\n",
"\n",
"What we actually do is give the LLM the opportunity to inform us that it wants us to run the tool.\n",
"\n",
"Here's how the new chat function looks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce9b0744-9c78-408d-b9df-9f6fd9ed78cf",
"metadata": {},
"outputs": [],
"source": [
"def chat(message, history):\n",
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
"\n",
" if response.choices[0].finish_reason==\"tool_calls\":\n",
" message = response.choices[0].message\n",
" response, city = handle_tool_call(message)\n",
" messages.append(message)\n",
" messages.append(response)\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
" \n",
" return response.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0992986-ea09-4912-a076-8e5603ee631f",
"metadata": {},
"outputs": [],
"source": [
"# We have to write that function handle_tool_call:\n",
"\n",
"def handle_tool_call(message):\n",
" tool_call = message.tool_calls[0]\n",
" arguments = json.loads(tool_call.function.arguments)\n",
" city = arguments.get('destination_city')\n",
" price = get_ticket_price(city)\n",
" response = {\n",
" \"role\": \"tool\",\n",
" \"content\": json.dumps({\"destination_city\": city,\"price\": price}),\n",
" \"tool_call_id\": tool_call.id\n",
" }\n",
" return response, city"
]
},
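{
"cell_type": "markdown",
"id": "sketch-multi-tool-calls-md",
"metadata": {},
"source": [
"The `handle_tool_call` function above reads only `message.tool_calls[0]`. A model can request several tool calls in a single response, and each one needs its own `tool` message with the matching `tool_call_id`. The cell below is a minimal sketch of one way to handle that, not part of the original project: `handle_tool_calls`, `tool_functions`, `demo_price`, `fake_tool_call` and the `SimpleNamespace` stand-in message are all hypothetical names introduced for illustration, so it runs without calling the API.\n",
"\n",
"In `chat`, the returned list could be added with `messages.extend(...)` in place of the single `messages.append(response)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-multi-tool-calls-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch (not from the original notebook): reply to every tool call in one message\n",
"\n",
"import json\n",
"from types import SimpleNamespace\n",
"\n",
"def handle_tool_calls(message, tool_functions):\n",
"    responses = []\n",
"    for tool_call in message.tool_calls:\n",
"        # Look up the Python function by the name the model asked for\n",
"        function = tool_functions[tool_call.function.name]\n",
"        arguments = json.loads(tool_call.function.arguments)\n",
"        result = function(**arguments)\n",
"        responses.append({\n",
"            \"role\": \"tool\",\n",
"            \"content\": json.dumps({**arguments, \"result\": result}),\n",
"            \"tool_call_id\": tool_call.id\n",
"        })\n",
"    return responses\n",
"\n",
"# Stand-ins for a price lookup and an API response, so the sketch runs offline\n",
"def demo_price(destination_city):\n",
"    return {\"london\": \"$799\", \"berlin\": \"$499\"}.get(destination_city.lower(), \"Unknown\")\n",
"\n",
"def fake_tool_call(call_id, city):\n",
"    arguments = json.dumps({\"destination_city\": city})\n",
"    return SimpleNamespace(id=call_id, function=SimpleNamespace(name=\"get_ticket_price\", arguments=arguments))\n",
"\n",
"fake_message = SimpleNamespace(tool_calls=[fake_tool_call(\"call_1\", \"London\"), fake_tool_call(\"call_2\", \"Berlin\")])\n",
"handle_tool_calls(fake_message, {\"get_ticket_price\": demo_price})"
]
},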
{
"cell_type": "code",
"execution_count": null,
"id": "f4be8a71-b19e-4c2f-80df-f59ff2661f14",
"metadata": {},
"outputs": [],
"source": [
"gr.ChatInterface(fn=chat, type=\"messages\").launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}