Commit 55d1cfb
Parent: 26ebe22

fix example issue

Files changed:
- Gradio_UI.py +27 -67
- UIexample.py +221 -0
- UIexample2.py +230 -0
- UItest.py +55 -0
- __pycache__/Gradio_UI.cpython-312.pyc +0 -0
- app.py +2 -2
- tools/QCMTool.py +4 -2
- tools/__pycache__/QCMTool.cpython-312.pyc +0 -0
Gradio_UI.py
CHANGED
@@ -192,61 +192,16 @@ class GradioUI:

     def interact_with_agent(self, prompt, messages):
         import gradio as gr
-
         messages.append(gr.ChatMessage(role="user", content=prompt))
         yield messages
+        print("ok")
+        print(messages)
         for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
             messages.append(msg)
             yield messages
         yield messages

-    def upload_file(
-        self,
-        file,
-        file_uploads_log,
-        allowed_file_types=[
-            "application/pdf",
-            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-            "text/plain",
-        ],
-    ):
-        """
-        Handle file uploads, default allowed types are .pdf, .docx, and .txt
-        """
-        import gradio as gr
-
-        if file is None:
-            return gr.Textbox("No file uploaded", visible=True), file_uploads_log
-
-        try:
-            mime_type, _ = mimetypes.guess_type(file.name)
-        except Exception as e:
-            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
-
-        if mime_type not in allowed_file_types:
-            return gr.Textbox("File type disallowed", visible=True), file_uploads_log
-
-        # Sanitize file name
-        original_name = os.path.basename(file.name)
-        sanitized_name = re.sub(
-            r"[^\w\-.]", "_", original_name
-        )  # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
-
-        type_to_ext = {}
-        for ext, t in mimetypes.types_map.items():
-            if t not in type_to_ext:
-                type_to_ext[t] = ext
-
-        # Ensure the extension correlates to the mime type
-        sanitized_name = sanitized_name.split(".")[:-1]
-        sanitized_name.append("" + type_to_ext[mime_type])
-        sanitized_name = "".join(sanitized_name)
-
-        # Save the uploaded file to the specified folder
-        file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
-        shutil.copy(file.name, file_path)

-        return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

     def log_user_message(self, text_input, file_uploads_log):
         return (
@@ -259,19 +214,26 @@ class GradioUI:
             "",
         )

-
-        if x.value["text"] is not None:
-            history.append((x.value["text"], None))
-        if "files" in x.value:
-            if isinstance(x.value["files"], list):
-                for file in x.value["files"]:
-                    history.append((file, None))
-            else:
-                history.append((x.value["files"], None))
-        return history
+

     def launch(self, **kwargs):
         import gradio as gr
+        def append_example_message(x: gr.SelectData, messages):
+            if x.value["text"] is not None:
+                message = x.value["text"]
+            if "files" in x.value:
+                if isinstance(x.value["files"], list):
+                    message = "Here are the files: "
+                    for file in x.value["files"]:
+                        message += f"{file}, "
+                else:
+                    message = x.value["files"]
+            messages.append(gr.ChatMessage(role="user", content=message))
+            #print(message)
+            #messages=message
+            #return messages
+            return message
+
         examples = [
             {
                 "text": "Calculate the VaR for returns: 0.1, -0.2, 0.05, -0.15, 0.3",  # Message to populate
@@ -302,24 +264,22 @@ class GradioUI:
             resizeable=True,
             scale=1,
             # Description
-            examples=examples,
+            examples=examples,
+            placeholder="""<h1>FRM Study chatbot</h1>""",
+            # Example inputs
         )
-
+
+
         # If an upload folder is provided, enable the upload feature
-        if self.file_upload_folder is not None:
-            upload_file = gr.File(label="Upload a file")
-            upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
-            upload_file.change(
-                self.upload_file,
-                [upload_file, file_uploads_log],
-                [upload_status, file_uploads_log],
-            )
         text_input = gr.Textbox(lines=1, label="Chat Message")
         text_input.submit(
             self.log_user_message,
             [text_input, file_uploads_log],
             [stored_messages, text_input],
         ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
+        chatbot.example_select(append_example_message, chatbot, text_input)#.then(self.interact_with_agent, chatbot, chatbot)
+
+


         demo.launch(debug=True, share=True, **kwargs)
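Note: the example_select wiring that this commit settles on in Gradio_UI.py reduces to a small self-contained demo. The sketch below is illustrative only, not the committed code: it assumes a Gradio version whose gr.Chatbot supports examples=, example_select, and gr.ChatMessage; echo() is a hypothetical stand-in for stream_to_gradio and the real agent; and unlike the committed handler (which fills the textbox) it chains the agent call directly off the example click.

# Minimal sketch of the example_select wiring (illustrative; echo() is a
# hypothetical stand-in for stream_to_gradio + the real agent).
import gradio as gr

def append_example_message(x: gr.SelectData, messages):
    # Build the message text from the selected example ("text" and/or "files").
    message = x.value.get("text") or ""
    if "files" in x.value:
        files = x.value["files"]
        message = ", ".join(files) if isinstance(files, list) else files
    # Append it to the history so the agent step can pick it up.
    messages.append(gr.ChatMessage(role="user", content=message))
    return messages

def echo(history):
    # Gradio may hand the history back as dicts or ChatMessage objects.
    last = history[-1]
    prompt = last["content"] if isinstance(last, dict) else last.content
    history.append(gr.ChatMessage(role="assistant", content=f"Echo: {prompt}"))
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", examples=[{"text": "Hello!"}])
    # Clicking an in-chat example appends it as a user turn, then runs the agent.
    chatbot.example_select(append_example_message, chatbot, chatbot).then(
        echo, chatbot, chatbot
    )

if __name__ == "__main__":
    demo.launch()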
UIexample.py
ADDED
@@ -0,0 +1,221 @@
+from __future__ import annotations as _annotations
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Any
+
+import gradio as gr
+from dotenv import load_dotenv
+from httpx import AsyncClient
+from pydantic_ai import Agent, ModelRetry, RunContext
+from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
+
+load_dotenv()
+
+
+@dataclass
+class Deps:
+    client: AsyncClient
+    weather_api_key: str | None
+    geo_api_key: str | None
+
+
+weather_agent = Agent(
+    "openai:gpt-4o",
+    system_prompt="You are an expert packer. A user will ask you for help packing for a trip given a destination. Use your weather tools to provide a concise and effective packing list. Also ask follow up questions if neccessary.",
+    deps_type=Deps,
+    retries=2,
+)
+
+
+@weather_agent.tool
+async def get_lat_lng(
+    ctx: RunContext[Deps], location_description: str
+) -> dict[str, float]:
+    """Get the latitude and longitude of a location.
+    Args:
+        ctx: The context.
+        location_description: A description of a location.
+    """
+    if ctx.deps.geo_api_key is None:
+        # if no API key is provided, return a dummy response (London)
+        return {"lat": 51.1, "lng": -0.1}
+
+    params = {
+        "q": location_description,
+        "api_key": ctx.deps.geo_api_key,
+    }
+    r = await ctx.deps.client.get("https://geocode.maps.co/search", params=params)
+    r.raise_for_status()
+    data = r.json()
+
+    if data:
+        return {"lat": data[0]["lat"], "lng": data[0]["lon"]}
+    else:
+        raise ModelRetry("Could not find the location")
+
+
+@weather_agent.tool
+async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
+    """Get the weather at a location.
+    Args:
+        ctx: The context.
+        lat: Latitude of the location.
+        lng: Longitude of the location.
+    """
+    if ctx.deps.weather_api_key is None:
+        # if no API key is provided, return a dummy response
+        return {"temperature": "21 °C", "description": "Sunny"}
+
+    params = {
+        "apikey": ctx.deps.weather_api_key,
+        "location": f"{lat},{lng}",
+        "units": "metric",
+    }
+    r = await ctx.deps.client.get(
+        "https://api.tomorrow.io/v4/weather/realtime", params=params
+    )
+    r.raise_for_status()
+    data = r.json()
+
+    values = data["data"]["values"]
+    # https://docs.tomorrow.io/reference/data-layers-weather-codes
+    code_lookup = {
+        1000: "Clear, Sunny",
+        1100: "Mostly Clear",
+        1101: "Partly Cloudy",
+        1102: "Mostly Cloudy",
+        1001: "Cloudy",
+        2000: "Fog",
+        2100: "Light Fog",
+        4000: "Drizzle",
+        4001: "Rain",
+        4200: "Light Rain",
+        4201: "Heavy Rain",
+        5000: "Snow",
+        5001: "Flurries",
+        5100: "Light Snow",
+        5101: "Heavy Snow",
+        6000: "Freezing Drizzle",
+        6001: "Freezing Rain",
+        6200: "Light Freezing Rain",
+        6201: "Heavy Freezing Rain",
+        7000: "Ice Pellets",
+        7101: "Heavy Ice Pellets",
+        7102: "Light Ice Pellets",
+        8000: "Thunderstorm",
+    }
+    return {
+        "temperature": f'{values["temperatureApparent"]:0.0f}°C',
+        "description": code_lookup.get(values["weatherCode"], "Unknown"),
+    }
+
+
+TOOL_TO_DISPLAY_NAME = {"get_lat_lng": "Geocoding API", "get_weather": "Weather API"}
+
+client = AsyncClient()
+weather_api_key = os.getenv("WEATHER_API_KEY")
+# create a free API key at https://geocode.maps.co/
+geo_api_key = os.getenv("GEO_API_KEY")
+deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
+
+
+async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
+    chatbot.append({"role": "user", "content": prompt})
+    yield gr.Textbox(interactive=False, value=""), chatbot, gr.skip()
+    async with weather_agent.run_stream(
+        prompt, deps=deps, message_history=past_messages
+    ) as result:
+        for message in result.new_messages():
+            past_messages.append(message)
+            if isinstance(message, ModelStructuredResponse):
+                for call in message.calls:
+                    gr_message = {
+                        "role": "assistant",
+                        "content": "",
+                        "metadata": {
+                            "title": f"### 🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}",
+                            "id": call.tool_id,
+                        },
+                    }
+                    chatbot.append(gr_message)
+            if isinstance(message, ToolReturn):
+                for gr_message in chatbot:
+                    if gr_message.get("metadata", {}).get("id", "") == message.tool_id:
+                        gr_message["content"] = f"Output: {json.dumps(message.content)}"
+            yield gr.skip(), chatbot, gr.skip()
+        chatbot.append({"role": "assistant", "content": ""})
+        async for message in result.stream_text():
+            chatbot[-1]["content"] = message
+            yield gr.skip(), chatbot, gr.skip()
+        data = await result.get_data()
+        past_messages.append(ModelTextResponse(content=data))
+        yield gr.Textbox(interactive=True), gr.skip(), past_messages
+
+
+async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
+    new_history = chatbot[: retry_data.index]
+    previous_prompt = chatbot[retry_data.index]["content"]
+    past_messages = past_messages[: retry_data.index]
+    async for update in stream_from_agent(previous_prompt, new_history, past_messages):
+        yield update
+
+
+def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
+    new_history = chatbot[: undo_data.index]
+    past_messages = past_messages[: undo_data.index]
+    return chatbot[undo_data.index]["content"], new_history, past_messages
+
+
+def select_data(message: gr.SelectData) -> str:
+    return message.value["text"]
+
+
+with gr.Blocks() as demo:
+    gr.HTML(
+        """
+<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
+    <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
+    <div>
+        <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
+        <h3 style="margin: 0 0 0.5rem 0">
+            This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
+        </h3>
+        <h3 style="margin: 0">
+            Feel free to ask for help with any other questions you have about your trip!
+        </h3>
+    </div>
+</div>
+"""
+    )
+    past_messages = gr.State([])
+    chatbot = gr.Chatbot(
+        label="Packing Assistant",
+        type="messages",
+        avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
+        examples=[
+            {"text": "I am going to Paris for the holidays, what should I pack?"},
+            {"text": "I am going to Tokyo this week."},
+        ],
+    )
+    with gr.Row():
+        prompt = gr.Textbox(
+            lines=1,
+            show_label=False,
+            placeholder="I am planning a trip to Miami, what should I pack?",
+        )
+    generation = prompt.submit(
+        stream_from_agent,
+        inputs=[prompt, chatbot, past_messages],
+        outputs=[prompt, chatbot, past_messages],
+    )
+    chatbot.example_select(select_data, None, [prompt])
+    chatbot.retry(
+        handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages]
+    )
+    chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages])
+
+
+if __name__ == "__main__":
+    demo.launch()
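Note: stream_from_agent above is bound to three outputs but updates only some of them per yield via gr.skip(). Below is a stripped-down sketch of that technique, assuming a Gradio version that provides gr.skip(); the counter demo and every name in it are hypothetical, not part of this commit.

# Stripped-down sketch of the gr.skip() streaming pattern: a generator bound
# to several outputs leaves some of them unchanged on a given yield.
import time
import gradio as gr

def count_up(n, log):
    yield gr.Textbox(interactive=False), gr.skip()   # lock the input, keep the log
    log = ""
    for i in range(int(n)):
        time.sleep(0.2)
        log += f"tick {i}\n"
        yield gr.skip(), log                         # update only the log
    yield gr.Textbox(interactive=True), gr.skip()    # unlock the input at the end

with gr.Blocks() as demo:
    n = gr.Textbox(value="3", label="Ticks")
    log = gr.Textbox(label="Log")
    n.submit(count_up, [n, log], [n, log])

if __name__ == "__main__":
    demo.launch()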
UIexample2.py
ADDED
@@ -0,0 +1,230 @@
+#https://huggingface.co/spaces/vonliechti/SQuAD_Agent_Experiment/blob/main/app.py
+import gradio as gr
+from gradio import ChatMessage
+from utils import stream_from_transformers_agent
+from gradio.context import Context
+from gradio import Request
+import pickle
+import os
+from dotenv import load_dotenv
+from agent import get_agent, DEFAULT_TASK_SOLVING_TOOLBOX
+from transformers.agents import (
+    DuckDuckGoSearchTool,
+    ImageQuestionAnsweringTool,
+    VisitWebpageTool,
+)
+from tools.text_to_image import TextToImageTool
+from PIL import Image
+from transformers import load_tool
+from prompts import (
+    DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
+    FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT,
+)
+from pygments.formatters import HtmlFormatter
+
+
+load_dotenv()
+
+SESSION_PERSISTENCE_ENABLED = os.getenv("SESSION_PERSISTENCE_ENABLED", False)
+
+sessions_path = "sessions.pkl"
+sessions = (
+    pickle.load(open(sessions_path, "rb"))
+    if SESSION_PERSISTENCE_ENABLED and os.path.exists(sessions_path)
+    else {}
+)
+
+# If currently hosted on HuggingFace Spaces, use the default model, otherwise use the local model
+model_name = (
+    "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    if os.getenv("SPACE_ID") is not None
+    else "http://localhost:1234/v1"
+)
+
+"""
+The ImageQuestionAnsweringTool from Transformers Agents 2.0 has a bug where
+it said it accepts the path to an image, but it does not.
+This class uses the adapter pattern to fix the issue, in a way that may be
+compatible with future versions of the tool even if the bug is fixed.
+"""
+class FixImageQuestionAnsweringTool(ImageQuestionAnsweringTool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode(self, image: "Image | str", question: str):
+        if isinstance(image, str):
+            image = Image.open(image)
+        return super().encode(image, question)
+
+"""
+The app version of the agent has access to additional tools that are not available
+during benchmarking. We chose this approach to focus benchmarking on the agent's
+ability to solve questions about the SQuAD dataset, without the help of general
+knowledge available on the web. For the purposes of the project, the demo
+app has access to additional tools to provide a more interactive and engaging experience.
+"""
+ADDITIONAL_TOOLS = [
+    DuckDuckGoSearchTool(),
+    VisitWebpageTool(),
+    FixImageQuestionAnsweringTool(),
+    load_tool("speech_to_text"),
+    load_tool("text_to_speech"),
+    load_tool("translation"),
+    TextToImageTool(),
+]
+
+# Add image tools to the default task solving toolbox, for a more visually interactive experience
+TASK_SOLVING_TOOLBOX = DEFAULT_TASK_SOLVING_TOOLBOX + ADDITIONAL_TOOLS
+
+# Using the focused prompt, which was the top-performing prompt during benchmarking
+system_prompt = FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT
+
+agent = get_agent(
+    model_name=model_name,
+    toolbox=TASK_SOLVING_TOOLBOX,
+    system_prompt=system_prompt,
+    use_openai=True,  # Use OpenAI instead of a local or HF model as the base LLM engine
+)
+
+def append_example_message(x: gr.SelectData, messages):
+    if x.value["text"] is not None:
+        message = x.value["text"]
+    if "files" in x.value:
+        if isinstance(x.value["files"], list):
+            message = "Here are the files: "
+            for file in x.value["files"]:
+                message += f"{file}, "
+        else:
+            message = x.value["files"]
+    messages.append(ChatMessage(role="user", content=message))
+    return messages
+
+
+def add_message(message, messages):
+    messages.append(ChatMessage(role="user", content=message))
+    return messages
+
+
+def interact_with_agent(messages, request: Request):
+    session_hash = request.session_hash
+    prompt = messages[-1]["content"]
+    agent.logs = sessions.get(session_hash + "_logs", [])
+    yield messages, gr.update(
+        value="<center><h1>Thinking...</h1></center>", visible=True
+    )
+    for msg in stream_from_transformers_agent(agent, prompt):
+        if isinstance(msg, ChatMessage):
+            messages.append(msg)
+            yield messages, gr.update(visible=True)
+        else:
+            yield messages, gr.update(
+                value=f"<center><h1>{msg}</h1></center>", visible=True
+            )
+    yield messages, gr.update(value="<center><h1>Idle</h1></center>", visible=False)
+
+
+def persist(component):
+
+    def resume_session(value, request: Request):
+        session_hash = request.session_hash
+        print(f"Resuming session for {session_hash}")
+        state = sessions.get(session_hash, value)
+        agent.logs = sessions.get(session_hash + "_logs", [])
+        return state
+
+    def update_session(value, request: Request):
+        session_hash = request.session_hash
+        print(f"Updating persisted session state for {session_hash}")
+        sessions[session_hash] = value
+        sessions[session_hash + "_logs"] = agent.logs
+        if SESSION_PERSISTENCE_ENABLED:
+            pickle.dump(sessions, open(sessions_path, "wb"))
+
+    Context.root_block.load(resume_session, inputs=[component], outputs=component)
+    component.change(update_session, inputs=[component], outputs=None)
+
+    return component
+
+
+from gradio.components import (
+    Component as GradioComponent,
+)
+from gradio.components.chatbot import (
+    Chatbot,
+    FileDataDict,
+    FileData,
+    ComponentMessage,
+    FileMessage,
+)
+
+
+class CleanChatBot(Chatbot):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _postprocess_content(
+        self,
+        chat_message: (
+            str | tuple | list | FileDataDict | FileData | GradioComponent | None
+        ),
+    ) -> str | FileMessage | ComponentMessage | None:
+        response = super()._postprocess_content(chat_message)
+        print(f"Post processing content: {response}")
+        if isinstance(response, ComponentMessage):
+            print(f"Setting open to False for {response}")
+            response.props["open"] = False
+        return response
+
+
+with gr.Blocks(
+    fill_height=True,
+    css=".gradio-container .message .content {text-align: left;}"
+    + HtmlFormatter().get_style_defs(".highlight"),
+) as demo:
+    state = gr.State()
+    inner_monologue_component = gr.Markdown(
+        """<h2>Inner Monologue</h2>""", visible=False
+    )
+    chatbot = persist(
+        gr.Chatbot(
+            value=[],
+            label="SQuAD Agent",
+            type="messages",
+            avatar_images=(
+                None,
+                "SQuAD.png",
+            ),
+            scale=1,
+            autoscroll=True,
+            show_copy_all_button=True,
+            show_copy_button=True,
+            placeholder="""<h1>SQuAD Agent</h1>
+<h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
+<h2>You can ask me questions about the dataset. You can also ask me to create images
+to help illustrate the topics under discussion, or expand the discussion beyond the dataset.</h2>
+""",
+            examples=[
+                {
+                    "text": "What is on top of the Notre Dame building?",
+                },
+                {
+                    "text": "What is the Olympic Torch made of?",
+                },
+                {
+                    "text": "Draw a picture of whatever is on top of the Notre Dame building.",
+                },
+            ],
+        )
+    )
+    text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
+    chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
+    bot_msg = chat_msg.then(
+        interact_with_agent, [chatbot], [chatbot, inner_monologue_component]
+    )
+    text_input.submit(lambda: "", None, text_input)
+    chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
+        interact_with_agent, [chatbot], [chatbot, inner_monologue_component]
+    )
+
+if __name__ == "__main__":
+    demo.launch()
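Note: persist() above is the reusable piece of UIexample2.py: restore a component's value for the current session_hash on page load, write it back on every change. Below is an in-memory-only sketch of the same pattern; it uses the public demo.load hook instead of the internal Context.root_block.load, and it omits the disk pickling and agent.logs handling the real file does. All names in the sketch are illustrative.

# In-memory sketch of the persist() pattern: resume a component's value per
# session_hash on load, save it back on change.
import gradio as gr

sessions = {}  # session_hash -> last value (the real file pickles this dict)

def make_persistent(component):
    def resume(value, request: gr.Request):
        return sessions.get(request.session_hash, value)

    def save(value, request: gr.Request):
        sessions[request.session_hash] = value

    # Called while the Blocks context is still open, so events attach directly.
    demo.load(resume, inputs=component, outputs=component)
    component.change(save, inputs=component, outputs=None)
    return component

with gr.Blocks() as demo:
    note = make_persistent(gr.Textbox(label="Sticky note"))

if __name__ == "__main__":
    demo.launch()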
UItest.py
ADDED
@@ -0,0 +1,55 @@
+import gradio as gr
+import os
+# Multimodal Chatbot demo that shows support for examples (example messages shown within the chatbot).
+
+def print_like_dislike(x: gr.LikeData):
+    print(x.index, x.value, x.liked)
+
+def add_message(history, message):
+    for x in message["files"]:
+        history.append(((x,), None))
+    if message["text"] is not None:
+        history.append((message["text"], None))
+    return history, gr.MultimodalTextbox(value=None, interactive=False)
+
+def append_example_message(x: gr.SelectData, history):
+    if x.value["text"] is not None:
+        history.append((x.value["text"], None))
+    if "files" in x.value:
+        if isinstance(x.value["files"], list):
+            for file in x.value["files"]:
+                history.append((file, None))
+        else:
+            history.append((x.value["files"], None))
+    return history
+
+def respond(history):
+    history[-1][1] = "Cool!"
+    return history
+
+with gr.Blocks(fill_height=True) as demo:
+    chatbot = gr.Chatbot(
+        elem_id="chatbot",
+        bubble_full_width=False,
+        scale=1,
+        placeholder='<h1 style="font-weight: bold; color: #FFFFFF; text-align: center; font-size: 48px; font-family: Arial, sans-serif;">Welcome to Gradio!</h1>',
+        examples=[{"icon": os.path.join(os.path.dirname(__file__), "files/avatar.png"), "display_text": "Display Text Here!", "text": "Try this example with this audio.", "files": [os.path.join(os.path.dirname(__file__), "files/cantina.wav")]},
+                  {"text": "Try this example with this image.", "files": [os.path.join(os.path.dirname(__file__), "files/avatar.png")]},
+                  {"text": "This is just text, no files!"},
+                  {"text": "Try this example with this image.", "files": [os.path.join(os.path.dirname(__file__), "files/avatar.png"), os.path.join(os.path.dirname(__file__), "files/avatar.png")]},
+                  {"text": "Try this example with this Audio.", "files": [os.path.join(os.path.dirname(__file__), "files/cantina.wav")]}]
+    )
+
+    chat_input = gr.MultimodalTextbox(interactive=True,
+                                      file_count="multiple",
+                                      placeholder="Enter message or upload file...", show_label=False)
+
+    chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+    bot_msg = chat_msg.then(respond, chatbot, chatbot, api_name="bot_response")
+    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+
+    chatbot.like(print_like_dislike, None, None)
+    chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(respond, chatbot, chatbot, api_name="respond")
+
+if __name__ == "__main__":
+    demo.launch()
__pycache__/Gradio_UI.cpython-312.pyc
CHANGED
Binary files a/__pycache__/Gradio_UI.cpython-312.pyc and b/__pycache__/Gradio_UI.cpython-312.pyc differ
app.py
CHANGED
@@ -115,7 +115,7 @@ def provide_my_information(query: str) -> str:
     return "I'm sorry, I don't have information on that. Please ask about my name, location, occupation, education, skills, hobbies, or contact details."


-qcm_tool = QCMTool("info/questions.json")
+#qcm_tool = QCMTool("info/questions.json")
 final_answer = FinalAnswerTool()
 visit_webpage = VisitWebpageTool()
 web_search = DuckDuckGoSearchTool()
@@ -141,7 +141,7 @@ with open("prompts.yaml", 'r') as stream:

 agent = CodeAgent(
     model=model,
-    tools=[final_answer,calculate_risk_metrics,
+    tools=[final_answer,calculate_risk_metrics,visit_webpage,web_search,provide_my_information], ## add your tools here (don't remove final answer)
     max_steps=6,
     verbosity_level=1,
     grammar=None,
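Note: this hunk completes the CodeAgent tool list that an earlier edit had left truncated. For context, a minimal smolagents wiring of the same shape is sketched below; HfApiModel and the tool choices are illustrative placeholders, not what app.py actually configures.

# Rough sketch of the CodeAgent wiring that this hunk completes (smolagents;
# the model and tool list here are placeholders, not app.py's configuration).
from smolagents import CodeAgent, DuckDuckGoSearchTool, FinalAnswerTool, HfApiModel

model = HfApiModel()  # placeholder; app.py builds its own model
agent = CodeAgent(
    model=model,
    tools=[FinalAnswerTool(), DuckDuckGoSearchTool()],  # keep final_answer in the list
    max_steps=6,
    verbosity_level=1,
)

if __name__ == "__main__":
    print(agent.run("What is 2 + 2?"))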
tools/QCMTool.py
CHANGED
@@ -3,6 +3,7 @@ from smolagents.tools import Tool
 import json
 import random

+
 class QCMTool(Tool):
     """
     A tool for running multiple-choice question (QCM) quizzes.
@@ -106,10 +107,11 @@ class QCMTool(Tool):
         else:
             return f"Incorrect! 😞\nExplanation: {explanation}"

+
 if __name__ == "__main__":
     # Initialize the QCM tool
     qcm_tool = QCMTool(json_file="../info/questions.json")
-    question=qcm_tool._pick_random_question()
+    question = qcm_tool._pick_random_question()
     print(question)

     # Simulate a user answering 'A'
@@ -117,4 +119,4 @@ if __name__ == "__main__":
         result = qcm_tool.forward(user_answer="A")
         print(result)
     except ValueError as e:
-        print(f"Error: {e}")
+        print(f"Error: {e}")
tools/__pycache__/QCMTool.cpython-312.pyc
CHANGED
Binary files a/tools/__pycache__/QCMTool.cpython-312.pyc and b/tools/__pycache__/QCMTool.cpython-312.pyc differ