wt002 commited on
Commit
a8fe6cd
·
verified ·
1 Parent(s): 7c40d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -124
app.py CHANGED
@@ -1,146 +1,321 @@
1
  import os
2
- from dotenv import load_dotenv
3
  import gradio as gr
 
 
 
 
 
4
  import requests
5
-
6
- from typing import List, Dict, Union, Optional
7
  import pandas as pd
8
- import wikipediaapi
9
- import requests
10
- #from bs4 import BeautifulSoup
11
- import random
12
- import re
13
- from typing import Optional
14
- from datetime import datetime
15
- import google.generativeai as genai
 
 
 
 
16
 
17
  load_dotenv()
18
 
19
  # (Keep Constants as is)
20
  # --- Constants ---
21
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
22
-
23
 
24
  # --- Basic Agent Definition ---
 
 
25
 
26
- class BasicAgent:
27
- def __init__(self, model_name: str = "gemini-pro"):
28
- """
29
- Multi-modal agent powered by Google Gemini with:
30
- - Web search
31
- - Wikipedia access
32
- - Document processing
33
- """
34
- self.model = genai.GenerativeModel(model_name)
35
- self.wiki = wikipediaapi.Wikipedia('en')
36
- self.searx_url = "https://searx.space/search" # Public Searx instance
37
-
38
- print("BasicAgent initialized.")
39
-
40
- def __call__(self, question: str) -> str:
41
- print(f"Agent received question (first 50 chars): {question[:50]}...")
42
- fixed_answer = self.process_request(question)
43
- print(f"Agent returning answer: {fixed_answer}")
44
- return fixed_answer
45
-
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- def generate_response(self, prompt: str) -> str:
49
- """Get response from Gemini"""
50
- try:
51
- response = self.model.generate_content(prompt)
52
- return response.text
53
- except Exception as e:
54
- return f"Error generating response: {str(e)}"
55
-
56
- def web_search(self, query: str) -> List[Dict]:
57
- """Use SearxNG meta-search engine"""
58
- params = {
59
- "q": query,
60
- "format": "json",
61
- "engines": "google,bing,duckduckgo"
62
- }
63
- try:
64
- response = requests.get(self.searx_url, params=params)
65
- response.raise_for_status()
66
- return response.json().get("results", [])
67
- except requests.RequestException:
68
- return []
69
-
70
- def wikipedia_search(self, query: str) -> str:
71
- """Get Wikipedia summary"""
72
- page = self.wiki.page(query)
73
- return page.summary if page.exists() else "No Wikipedia page found"
74
-
75
- def process_document(self, file_path: str) -> str:
76
- """Handle PDF, Word, CSV, Excel files"""
77
- if not os.path.exists(file_path):
78
- return "File not found"
79
-
80
- ext = os.path.splitext(file_path)[1].lower()
81
-
82
- try:
83
- if ext == '.pdf':
84
- return self._process_pdf(file_path)
85
- elif ext in ('.doc', '.docx'):
86
- return self._process_word(file_path)
87
- elif ext == '.csv':
88
- return pd.read_csv(file_path).to_string()
89
- elif ext in ('.xls', '.xlsx'):
90
- return pd.read_excel(file_path).to_string()
91
- else:
92
- return "Unsupported file format"
93
- except Exception as e:
94
- return f"Error processing document: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- def _process_pdf(self, file_path: str) -> str:
97
- """Process PDF using Gemini's vision capability"""
98
- try:
99
- # For Gemini 1.5 or later which supports file uploads
100
- with open(file_path, "rb") as f:
101
- file = genai.upload_file(f)
102
- response = self.model.generate_content(
103
- ["Extract and summarize the key points from this document:", file]
104
- )
105
- return response.text
106
- except:
107
- # Fallback for older Gemini versions
108
- try:
109
- import PyPDF2
110
- with open(file_path, 'rb') as f:
111
- reader = PyPDF2.PdfReader(f)
112
- return "\n".join([page.extract_text() for page in reader.pages])
113
- except ImportError:
114
- return "PDF processing requires PyPDF2 (pip install PyPDF2)"
115
-
116
- def _process_word(self, file_path: str) -> str:
117
- """Process Word documents"""
118
- try:
119
- from docx import Document
120
- doc = Document(file_path)
121
- return "\n".join([para.text for para in doc.paragraphs])
122
- except ImportError:
123
- return "Word processing requires python-docx (pip install python-docx)"
124
 
125
- def process_request(self, request: Union[str, Dict]) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  """
127
- Handle different request types:
128
- - Direct text queries
129
- - File processing requests
130
- - Complex multi-step requests
 
 
 
 
 
 
131
  """
132
- if isinstance(request, dict):
133
- if 'steps' in request:
134
- results = []
135
- for step in request['steps']:
136
- if step['type'] == 'search':
137
- results.append(self.web_search(step['query']))
138
- elif step['type'] == 'process':
139
- results.append(self.process_document(step['file']))
140
- return self.generate_response(f"Process these results: {results}")
141
- return "Unsupported request format"
142
-
143
- return self.generate_response(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
 
146
 
 
1
  import os
2
+ from typing import Annotated, Optional, TypedDict
3
  import gradio as gr
4
+ from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
5
+ from langchain_openai import ChatOpenAI
6
+ from langgraph.graph.message import add_messages
7
+ from langgraph.graph import StateGraph, START
8
+ from langgraph.prebuilt import tools_condition, ToolNode
9
  import requests
 
 
10
  import pandas as pd
11
+ from langchain.tools import Tool
12
+ from dotenv import load_dotenv
13
+
14
+ from arxiv_searcher import ArxivSearcher
15
+ from chess_algebraic_notation_retriever import ChessAlgebraicNotationMoveRetriever
16
+ from excel_file_reader import ExcelFileReader
17
+ from image_question_answer_tool import ImageQuestionAnswerTool
18
+ from python_code_question_answer_tool import PythonCodeQuestionAnswerTool
19
+ from tavily_searcher import TavilySearcher
20
+ from transcriber import Transcriber
21
+ from wikipedia_searcher import WikipediaSearcher
22
+ from youtube_video_question_answer_tool import YoutubeVideoQuestionAnswerTool
23
 
24
  load_dotenv()
25
 
26
  # (Keep Constants as is)
27
  # --- Constants ---
28
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
29
+ ASSOCIATED_FILE_ENDPOINT = f"{DEFAULT_API_URL}/files/"
30
 
31
  # --- Basic Agent Definition ---
32
+ # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
33
+ #search_tool = DuckDuckGoSearchRun()
34
 
35
+ #search_tool = DuckDuckGoSearcherTool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def retrieve_task_file(task_id: str) -> Optional[bytes]:
38
+ """
39
+ Retrieve the task file for a given task ID.
40
+ """
41
+ try:
42
+ response = requests.get(ASSOCIATED_FILE_ENDPOINT + task_id, timeout=15)
43
+ response.raise_for_status()
44
+ if response.status_code != 200:
45
+ print(f"Error fetching file: {response.status_code}")
46
+ return None
47
+ #print(f"Fetched file: {response.content}")
48
+ return response.content
49
+ except requests.exceptions.RequestException as e:
50
+ print(f"Error fetching file: {e}")
51
+ return None
52
+ except Exception as e:
53
+ print(f"An unexpected error occurred fetching file: {e}")
54
+ return None
55
 
56
+ def retrieve_next_chess_move_in_algebraic_notation(task_file_path: str, is_black_turn: bool) -> str:
57
+ """
58
+ Retrieve the next chess move in algebraic notation from an image path.
59
+ """
60
+ if task_file_path is None:
61
+ return "Error: Task file not found."
62
+ # Retrieve the next chess move in algebraic notation
63
+ next_chess_move = ChessAlgebraicNotationMoveRetriever().retrieve(task_file_path, is_black_turn)
64
+ return next_chess_move
65
+
66
+ # Initialize the tool
67
+ retrieve_next_chess_move_in_algebraic_notation_tool = Tool(
68
+ name="retrieve_next_chess_move_in_algebraic_notation",
69
+ func=retrieve_next_chess_move_in_algebraic_notation,
70
+ description="Retrieve the next chess move in algebraic notation from an image path."
71
+ )
72
+
73
+ def transcribe_audio(file_path: str) -> str:
74
+ if file_path is None:
75
+ return "Error: Audio path not found."
76
+ # Transcribe the audio
77
+ return Transcriber().transcribe(file_path)
78
+
79
+ # Initialize the tool
80
+ transcribe_audio_tool = Tool(
81
+ name="transcribe_audio",
82
+ func=transcribe_audio,
83
+ description="Transcribe the audio from an audio path."
84
+ )
85
+
86
+ # Initialize the tool
87
+ answer_python_code_tool = PythonCodeQuestionAnswerTool()
88
+
89
+ # Initialize the tool
90
+ answer_image_question_tool = ImageQuestionAnswerTool()
91
+
92
+ # Initialize the tool
93
+ answer_youtube_video_question_tool = YoutubeVideoQuestionAnswerTool()
94
+
95
+ '''def answer_youtube_video_question(youtube_video_url: str, question: str) -> str:
96
+ """
97
+ Answer the question based on the youtube video.
98
+ """
99
+ if youtube_video_url is None:
100
+ return "Error: Video not found."
101
+ # Download the video
102
+ video_path = YoutubeVideoDownloader().download_video(youtube_video_url)
103
+ # Answer the question
104
+ return VideoQuestionAnswer().answer(video_path, question)
105
+ # Initialize the tool
106
+ answer_youtube_video_question_tool = Tool(
107
+ name="answer_youtube_video_question",
108
+ func=answer_youtube_video_question,
109
+ description="Answer the question based on the youtube video."
110
+ )'''
111
+
112
+ def read_excel_file(file_path: str) -> str:
113
+ if file_path is None:
114
+ return "Error: File not found."
115
+ return ExcelFileReader().read_file(file_path)
116
+
117
+ # Initialize the tool
118
+ read_excel_file_tool = Tool(
119
+ name="read_excel_file",
120
+ func=read_excel_file,
121
+ description="Read the excel file."
122
+ )
123
+
124
+ # Initialize the tool
125
+ wikipedia_search_tool = Tool(
126
+ name="wikipedia_search",
127
+ func=WikipediaSearcher().search,
128
+ description="Search Wikipedia for a given query."
129
+ )
130
+
131
+ # Initialize the tool
132
+ arxiv_search_tool = Tool(
133
+ name="arxiv_search",
134
+ func=ArxivSearcher().search,
135
+ description="Search Arxiv for a given query."
136
+ )
137
+
138
+ tavily_search_tool = Tool(
139
+ name="tavily_search",
140
+ func=TavilySearcher().search,
141
+ description="Search the web for a given query."
142
+ )
143
+
144
+ def format_gaia_answer(answer: str) -> str:
145
+ llm = ChatOpenAI(model="o3-mini", openai_api_key=os.getenv("OPENAI_API_KEY"))
146
+ prompt = f"""
147
+ You are formatting answers for the GAIA benchmark, which requires responses to be concise and unambiguous.
148
+ Given the answer: {answer}
149
+ Return the answer in the correct GAIA format:
150
+ - If the answer is a single word or number, return it without any additional text or formatting.
151
+ - If the answer is a list, return a comma-separated list without any additional text or formatting.
152
+ - If the answer is a string, return it without any additional text or formatting.
153
+ Do not include any prefixes, dots, enumerations, explanations, or quotation marks.
154
+ Do not include any additional text or formatting.
155
+ """
156
+ response = llm.invoke(prompt)
157
+ # Delete double quotes
158
+ return response.content.strip().replace('"', '')
159
+
160
+ class AgentState(TypedDict):
161
+ # The document provided
162
+ messages: Annotated[list[AnyMessage], add_messages]
163
+ file_path: Optional[str]
164
+
165
+ class BasicAgent:
166
+ def __init__(self):
167
+ tools = [
168
+ tavily_search_tool,
169
+ arxiv_search_tool,
170
+ wikipedia_search_tool,
171
+ transcribe_audio_tool,
172
+ answer_python_code_tool,
173
+ answer_image_question_tool,
174
+ answer_youtube_video_question_tool,
175
+ read_excel_file_tool
176
+ ]
177
+ '''llm = ChatGoogleGenerativeAI(
178
+ model="gemini-2.0-flash",
179
+ temperature=0.2,
180
+ api_key=os.getenv("GEMINI_API_KEY")
181
+ )'''
182
+ llm = ChatOpenAI(model="o3-mini", openai_api_key=os.getenv("OPENAI_API_KEY"))
183
+ self.llm_with_tools = llm.bind_tools(tools)
184
+ builder = StateGraph(AgentState)
185
+
186
+ # Define nodes: these do the work
187
+ builder.add_node("assistant", self.assistant)
188
+ builder.add_node("tools", ToolNode(tools))
189
+
190
+ # Define edges: these determine how the control flow moves
191
+ builder.add_edge(START, "assistant")
192
+ builder.add_conditional_edges(
193
+ "assistant",
194
+ # If the latest message requires a tool, route to tools
195
+ # Otherwise, provide a direct response
196
+ tools_condition,
197
+ )
198
+ builder.add_edge("tools", "assistant")
199
+ self.agent = builder.compile()
200
 
201
+ print("BasicAgent initialized.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ def assistant(self, state: AgentState):
204
+ # System message
205
+ textual_description_of_tools="""
206
+ tavily_search(query: str) -> str:
207
+ Search the web for a given query.
208
+ Args:
209
+ query: Query to search the web for (string).
210
+ Returns:
211
+ A single string containing the information found on the web.
212
+ arxiv_search(query: str) -> str:
213
+ Search Arxiv, that contains scientific papers, for a given query.
214
+ Args:
215
+ query: Query to search Arxiv for (string).
216
+ Returns:
217
+ A single string containing the answer to the question.
218
+ wikipedia_search(query: str) -> str:
219
+ Search Wikipedia for a given query.
220
+ Args:
221
+ query: Query to search Wikipedia for (string).
222
+ Returns:
223
+ A single string containing the answer to the question.
224
+ transcribe_audio(file_path: str) -> str:
225
+ Transcribe the audio from an audio path.
226
+ Args:
227
+ file_path: File path of the audio file (string).
228
+ Returns:
229
+ A single string containing the transcribed text from the audio.
230
+
231
+ answer_python_code(file_path: str, question: str) -> str:
232
+ Answer the question based on the python code.
233
+ Args:
234
+ file_path: File path of the python file (string).
235
+ question: Question to answer (string).
236
+ Returns:
237
+ A single string containing the answer to the question.
238
+
239
+ answer_image_question(file_path: str, question: str) -> str:
240
+ Answer the question based on the image.
241
+ Args:
242
+ file_path: File path of the image (string).
243
+ question: Question to answer (string).
244
+ Returns:
245
+ A single string containing the answer to the question.
246
+
247
+ download_youtube_video(youtube_video_url: str) -> str:
248
+ Download the Youtube video into a local file based on the URL
249
+ Args:
250
+ youtube_video_url: A youtube video url (string).
251
+ Returns:
252
+ A single string containing the file path of the downloaded youtube video.
253
+ answer_youtube_video_question(file_path: str, question: str) -> str:
254
+ Answer the question based on file path of the downloaded youtube video
255
+ Args:
256
+ file_path: File path of the downloaded youtube video (string).
257
+ question: Question to answer (string).
258
+ Returns:
259
+ A single string containing the answer to the question.
260
+
261
+ read_excel_file(file_path: str) -> str:
262
+ Read the excel file.
263
+ Args:
264
+ file_path: File path of the excel file (string).
265
+ Returns:
266
+ A markdown formatted string containing the contents of the excel file.
267
  """
268
+ file_path=state["file_path"]
269
+ prompt = f"""
270
+ You are a helpful assistant that can analyse images, videos, excel files and Python scripts and run computations with provided tools:
271
+ {textual_description_of_tools}
272
+ You have access to the file path of the attached file in case it's informed. Currently the file path is: {file_path}
273
+ Be direct and specific. GAIA benchmark requires exact matching answers.
274
+ For example, if asked "What is the capital of France?", respond simply with "Paris".
275
+ Do not include any prefixes, dots, enumerations, explanations, or quotation marks.
276
+ Do not include any additional text or formatting.
277
+ If you are required a number, return a number, not the items.
278
  """
279
+ sys_msg = SystemMessage(content=prompt)
280
+
281
+ return {
282
+ "messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"], config={"configurable": {"file_path": state["file_path"]}})],
283
+ "file_path": state["file_path"]
284
+ }
285
+ '''return {
286
+ "messages": [self.llm_with_tools.invoke(
287
+ state["messages"],
288
+ config={"configurable": {"file_path": state["file_path"]}} # Aquí pasas el task_id
289
+ )],
290
+ "file_path": state["file_path"]
291
+ }'''
292
+
293
+ def __call__(self, question: str, task_id: str, file_name: str) -> str:
294
+ print(f"######################### Agent received question (first 50 chars): {question[:50]}... with file_name: {file_name}")
295
+
296
+ # Get the file path
297
+ tmp_file_path = None
298
+ if file_name is not None and file_name != "":
299
+ file_content = retrieve_task_file(task_id)
300
+ if file_content is not None:
301
+ print(f"Saving file {file_name} to tmp folder")
302
+ tmp_file_path = f"tmp/{file_name}"
303
+ with open(tmp_file_path, "wb") as f:
304
+ f.write(file_content)
305
+ # Show the file path
306
+ print(f"File path: {tmp_file_path}")
307
+
308
+ messages = self.agent.invoke({"messages": [HumanMessage(question)], "file_path": tmp_file_path})
309
+ # Show the messages
310
+ for m in messages['messages']:
311
+ m.pretty_print()
312
+ answer = messages["messages"][-1].content
313
+ answer = format_gaia_answer(answer)
314
+ print(f"######################### Agent returning answer: {answer}\n")
315
+ # Delete the file
316
+ if tmp_file_path is not None:
317
+ os.remove(tmp_file_path)
318
+ return answer
319
 
320
 
321