wt002 committed on
Commit
b011db5
·
verified ·
1 Parent(s): 81234d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -18
app.py CHANGED
@@ -1,36 +1,210 @@
1
  import os
2
- from dotenv import load_dotenv
3
- import gradio as gr
4
- import requests
5
-
6
- import os
7
- import inspect
8
  import gradio as gr
9
  import requests
10
  import pandas as pd
11
- from langchain_core.messages import HumanMessage
12
- from agent import build_graph
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- load_dotenv()
15
 
16
  # (Keep Constants as is)
17
  # --- Constants ---
18
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
 
20
 
21
- # --- Basic Agent Definition ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class BasicAgent:
 
24
  def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  print("BasicAgent initialized.")
26
- self.graph = build_graph()
27
-
28
- def __call__(self, question: str) -> str:
29
- print(f"Agent received question (first 50 chars): {question[:50]}...")
30
- messages = [HumanMessage(content=question)]
31
- result = self.graph.invoke({"messages": messages})
32
- answer = result['messages'][-1].content
33
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  def run_and_submit_all( profile: gr.OAuthProfile | None):
 
1
  import os
 
 
 
 
 
 
2
  import gradio as gr
3
  import requests
4
  import pandas as pd
5
+ from smolagents import CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool, VisitWebpageTool, tool, \
6
+ FinalAnswerTool, PythonInterpreterTool, SpeechToTextTool, ToolCallingAgent
7
+ import yaml
8
+ import importlib
9
+ from io import BytesIO
10
+ import tempfile
11
+ import base64
12
+ from youtube_transcript_api import YouTubeTranscriptApi
13
+ from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound, VideoUnavailable
14
+ from urllib.parse import urlparse, parse_qs
15
+ import json
16
+ import whisper
17
+ import re
18
+
19
 
 
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
23
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"  # base URL; BasicAgent appends /files/{task_id} to it to download a task's attached file
24
 
25
 
26
@tool
def transcribe_audio_file(file_path: str) -> str:
    """
    Transcribes a local MP3 audio file using Whisper.
    Args:
        file_path: Full path to the .mp3 audio file.
    Returns:
        A JSON-formatted string containing either the transcript or an error message.
        {
            "success": true,
            "transcript": [
                {"start": 0.0, "end": 5.2, "text": "Hello and welcome"},
                ...
            ]
        }
        OR
        {
            "success": false,
            "error": "Reason why transcription failed"
        }
    """
    # NOTE: the docstring above is consumed by the @tool decorator (it becomes
    # the tool description shown to the model), so its wording is kept as-is.
    # Every failure, expected or not, is reported as a JSON payload rather than
    # raised, so the calling agent always receives a string.
    try:
        # Validate the input first: existence is checked before the extension.
        if not os.path.exists(file_path):
            problem = "File does not exist."
        elif not file_path.lower().endswith(".mp3"):
            problem = "Invalid file type. Only MP3 files are supported."
        else:
            problem = None

        if problem is not None:
            return json.dumps({"success": False, "error": problem})

        # You can use 'tiny', 'base', 'small', 'medium', or 'large'
        speech_model = whisper.load_model("base")
        transcription = speech_model.transcribe(file_path, verbose=False, word_timestamps=False)

        # Keep only the timing and the whitespace-trimmed text of each segment.
        segments = [
            {"start": piece["start"], "end": piece["end"], "text": piece["text"].strip()}
            for piece in transcription["segments"]
        ]
        payload = {"success": True, "transcript": segments}
        return json.dumps(payload)

    except Exception as exc:
        return json.dumps({"success": False, "error": str(exc)})
70
+
71
+
72
@tool
def get_youtube_transcript(video_url: str) -> str:
    """
    Retrieves the transcript from a YouTube video URL, including timestamps.
    This tool fetches the English transcript for a given YouTube video. Automatically generated subtitles
    are also supported. The result includes each snippet's start time, duration, and text.
    Args:
        video_url: The URL of the YouTube video (e.g., https://www.youtube.com/watch?v=12345). Short links (https://youtu.be/12345) and /embed/, /shorts/ and /live/ URLs are accepted as well.
    Returns:
        A JSON-formatted string containing either the transcript with timestamps or an error message.
        {
            "success": true,
            "transcript": [
                {"start": 0.0, "duration": 1.54, "text": "Hey there"},
                {"start": 1.54, "duration": 4.16, "text": "how are you"},
                ...
            ]
        }
        OR
        {
            "success": false,
            "error": "Reason why the transcript could not be retrieved"
        }
    """
    # Every failure is reported as a JSON payload instead of being raised, so
    # the calling agent always receives a string it can reason about.
    try:
        # Preferred form: the video ID is carried in the ?v= query parameter.
        parsed_url = urlparse(video_url)
        video_id = parse_qs(parsed_url.query).get("v", [None])[0]

        # Fallback: some YouTube URL shapes carry the ID in the path instead
        # (youtu.be/<id>, /embed/<id>, /shorts/<id>, /live/<id>, /v/<id>).
        # The isinstance guard keeps non-string input on the original
        # "Invalid YouTube URL" path.
        if not video_id and isinstance(video_url, str):
            host = (parsed_url.hostname or "").lower()
            path_parts = [part for part in parsed_url.path.split("/") if part]
            is_short_link_host = host in ("youtu.be", "www.youtu.be")
            is_youtube_host = (
                host in ("youtube.com", "youtube-nocookie.com")
                or host.endswith((".youtube.com", ".youtube-nocookie.com"))
            )
            if is_short_link_host and path_parts:
                video_id = path_parts[0]
            elif (
                is_youtube_host
                and len(path_parts) >= 2
                and path_parts[0] in ("embed", "shorts", "live", "v")
            ):
                video_id = path_parts[1]

        if not video_id:
            return json.dumps({"success": False, "error": "Invalid YouTube URL. Could not extract video ID."})

        fetched_transcript = YouTubeTranscriptApi().fetch(video_id)
        transcript_data = [
            {
                "start": snippet.start,
                "duration": snippet.duration,
                "text": snippet.text,
            }
            for snippet in fetched_transcript
        ]

        return json.dumps({"success": True, "transcript": transcript_data})

    except VideoUnavailable:
        return json.dumps({"success": False, "error": "The video is unavailable."})
    except TranscriptsDisabled:
        return json.dumps({"success": False, "error": "Transcripts are disabled for this video."})
    except NoTranscriptFound:
        return json.dumps({"success": False, "error": "No transcript found for this video."})
    except Exception as e:
        return json.dumps({"success": False, "error": str(e)})
125
+
126
+ # --- Basic Agent Definition ---
127
# ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    """Question-answering agent built around a tool-equipped ``CodeAgent``.

    Instances are callable. A call optionally downloads the file attached to
    the task and folds it into the question text, runs the code agent, and
    returns whatever follows "FINAL ANSWER:" in the agent's output.
    """

    # Extensions that enrich_question_with_associated_file_details can embed.
    _SUPPORTED_FILE_SUFFIXES = (".mp3", ".py", ".xlsx", ".png")

    def __init__(self):
        """Create the OpenAI-backed model and the code agent with its tools."""
        # A plain `import importlib` does not guarantee that the
        # `importlib.resources` submodule is loaded, so import it explicitly
        # before using `importlib.resources.files`.
        import importlib.resources

        model = OpenAIServerModel(api_key=os.environ.get("OPENAI_API_KEY"), model_id="gpt-4o")

        self.code_agent = CodeAgent(
            tools=[
                PythonInterpreterTool(),
                DuckDuckGoSearchTool(),
                VisitWebpageTool(),
                transcribe_audio_file,
                get_youtube_transcript,
                FinalAnswerTool(),
            ],
            model=model,
            max_steps=20,
            name="hf_agent_course_final_assignment_solver",
            # Prompt templates are read from prompts/code_agent.yaml.
            prompt_templates=yaml.safe_load(
                importlib.resources.files("prompts").joinpath("code_agent.yaml").read_text()
            ),
        )
        print("BasicAgent initialized.")

    def __call__(self, task_id: str, question: str, file_name: str) -> str:
        """Answer ``question``, using the task's attached file when one is named.

        Args:
            task_id: Identifier of the task; used to download the attached file.
            question: The question text.
            file_name: Name of the attached file, or a falsy value when the
                task has no attachment.

        Returns:
            The stripped text after "FINAL ANSWER:" (case-insensitive) in the
            agent's output, or the whole stripped output when that marker is
            absent.

        Raises:
            requests.HTTPError: If downloading a supported attachment fails.
        """
        if file_name:
            question = self.enrich_question_with_associated_file_details(task_id, question, file_name)

        # The agent may return a non-string object; normalise it once.
        agent_output = str(self.code_agent.run(question))

        # Extract text after "FINAL ANSWER:" (case-insensitive, and trims whitespace)
        match = re.search(r'final answer:\s*(.*)', agent_output, re.IGNORECASE | re.DOTALL)
        if match:
            return match.group(1).strip()

        # Fallback in case the pattern is not found
        return agent_output.strip()

    def enrich_question_with_associated_file_details(self, task_id: str, question: str, file_name: str) -> str:
        """Append the task's attached file to ``question`` in a model-readable form.

        Supported attachments, matched case-insensitively on the extension:
        ``.mp3`` (saved to a temporary file whose path is appended), ``.py``
        (source appended), ``.xlsx`` (appended as CSV) and ``.png`` (appended
        as base64).

        Args:
            task_id: Identifier of the task whose file is downloaded.
            question: The original question text.
            file_name: Name of the attached file; only its extension is used.

        Returns:
            The question with the file details appended. For any other
            extension the question is returned unchanged and nothing is
            downloaded (previously this path returned ``None``, which was then
            handed to the agent as the prompt).

        Raises:
            requests.HTTPError: If the download responds with an error status.
        """
        lowered_name = file_name.lower()
        if not lowered_name.endswith(self._SUPPORTED_FILE_SUFFIXES):
            return question

        api_url = DEFAULT_API_URL
        get_associated_files_url = f"{api_url}/files/{task_id}"
        response = requests.get(get_associated_files_url, timeout=15)
        response.raise_for_status()

        if lowered_name.endswith(".mp3"):
            # delete=False: the file must outlive this call so the
            # transcription tool can open it by path later.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                tmp_file.write(response.content)
                file_path = tmp_file.name
            return question + "\n\nMentioned .mp3 file local path is: " + file_path

        if lowered_name.endswith(".py"):
            file_content = response.text
            return question + "\n\nBelow is mentioned Python file:\n\n```python\n" + file_content + "\n```\n"

        if lowered_name.endswith(".xlsx"):
            df = pd.read_excel(BytesIO(response.content))
            file_content = df.to_csv(index=False)
            return question + "\n\nBelow is mentioned excel file in CSV format:\n\n```csv\n" + file_content + "\n```\n"

        # Only ".png" remains among the supported suffixes.
        base64_str = base64.b64encode(response.content).decode('utf-8')
        return question + "\n\nBelow is the .png image in base64 format:\n\n```base64\n" + base64_str + "\n```\n"
207
+
208
 
209
 
210
  def run_and_submit_all( profile: gr.OAuthProfile | None):