codelion commited on
Commit
001b623
·
verified ·
1 Parent(s): 0ef8445

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -81
app.py CHANGED
@@ -1,14 +1,11 @@
1
  import os
2
  import json
3
- import tempfile
4
- import requests
5
  import gradio as gr
6
  import cv2
7
  from google import genai
8
  from google.genai import types
9
  from google.genai.types import Part
10
  from tenacity import retry, stop_after_attempt, wait_random_exponential
11
- import yt_dlp # Use yt-dlp for robust YouTube downloading
12
 
13
  # Retrieve API key from environment variables.
14
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
@@ -22,15 +19,17 @@ client = genai.Client(api_key=GOOGLE_API_KEY)
22
  MODEL_NAME = "gemini-2.0-flash-001"
23
 
24
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
25
- def call_gemini(video_url: str, prompt: str) -> str:
26
  """
27
- Call the Gemini model with the provided video URL and prompt.
28
- The video is passed as a URI part with MIME type "video/webm".
29
  """
 
 
30
  response = client.models.generate_content(
31
  model=MODEL_NAME,
32
  contents=[
33
- Part.from_uri(file_uri=video_url, mime_type="video/webm"),
34
  prompt,
35
  ],
36
  )
@@ -49,42 +48,10 @@ def hhmmss_to_seconds(time_str: str) -> float:
49
  else:
50
  return parts[0]
51
 
52
- def download_video(video_url: str) -> str:
53
- """
54
- Download the video from a URL. If it's a YouTube URL, use yt-dlp;
55
- otherwise, use requests for direct links.
56
- Returns the local file path.
57
- """
58
- local_file = None
59
- if "youtube.com" in video_url or "youtu.be" in video_url:
60
- ydl_opts = {
61
- 'format': 'mp4',
62
- 'outtmpl': '%(id)s.%(ext)s',
63
- 'noplaylist': True,
64
- 'quiet': True,
65
- }
66
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
67
- info = ydl.extract_info(video_url, download=True)
68
- local_file = ydl.prepare_filename(info)
69
- else:
70
- # Assume it's a direct link to a video file.
71
- response = requests.get(video_url, stream=True)
72
- if response.status_code == 200:
73
- temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
74
- for chunk in response.iter_content(chunk_size=8192):
75
- if chunk:
76
- temp_file.write(chunk)
77
- temp_file.flush()
78
- local_file = temp_file.name
79
- temp_file.close()
80
- else:
81
- raise ValueError("Failed to download video, status code: " + str(response.status_code))
82
- return local_file
83
-
84
- def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
85
  """
86
  Prompt Gemini to return key frame timestamps (in HH:MM:SS) with descriptions,
87
- then extract those frames from the downloaded video file using OpenCV.
88
 
89
  Returns a list of tuples: (image_array, caption)
90
  """
@@ -99,7 +66,7 @@ def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
99
  prompt += f" Additional focus: {user_query}"
100
 
101
  try:
102
- key_frames_response = call_gemini(video_url, prompt)
103
  # Attempt to parse the output as JSON.
104
  key_frames = json.loads(key_frames_response)
105
  if not isinstance(key_frames, list):
@@ -108,38 +75,32 @@ def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
108
  key_frames = []
109
 
110
  extracted_frames = []
111
- local_path = None
112
- try:
113
- local_path = download_video(video_url)
114
- cap = cv2.VideoCapture(local_path)
115
- if not cap.isOpened():
116
- print("Error: Could not open video from local file.")
117
- return extracted_frames
118
-
119
- for frame_obj in key_frames:
120
- ts = frame_obj.get("timestamp")
121
- description = frame_obj.get("description", "")
122
- try:
123
- seconds = hhmmss_to_seconds(ts)
124
- except Exception:
125
- continue
126
- # Set video position (in milliseconds)
127
- cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
128
- ret, frame = cap.read()
129
- if ret:
130
- # Convert BGR to RGB
131
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
132
- caption = f"{ts}: {description}"
133
- extracted_frames.append((frame_rgb, caption))
134
- cap.release()
135
- finally:
136
- if local_path and os.path.exists(local_path):
137
- os.remove(local_path)
138
  return extracted_frames
139
 
140
- def analyze_video(video_url: str, user_query: str) -> (str, list):
141
  """
142
- Perform iterative, agentic video analysis.
143
  First, refine the video analysis over several iterations.
144
  Then, prompt the model to identify key frames.
145
 
@@ -170,7 +131,7 @@ def analyze_video(video_url: str, user_query: str) -> (str, list):
170
  prompt += f" Remember to focus on: {user_query}"
171
 
172
  try:
173
- analysis = call_gemini(video_url, prompt)
174
  except Exception as e:
175
  analysis += f"\n[Error during iteration {i+1}: {e}]"
176
  break
@@ -179,7 +140,7 @@ def analyze_video(video_url: str, user_query: str) -> (str, list):
179
  markdown_report = f"## Video Analysis Report\n\n**Summary:**\n\n{analysis}\n"
180
 
181
  # Get key frames based on the analysis and optional query.
182
- key_frames_gallery = get_key_frames(video_url, analysis, user_query)
183
  if not key_frames_gallery:
184
  markdown_report += "\n*No key frames were extracted.*\n"
185
  else:
@@ -189,19 +150,19 @@ def analyze_video(video_url: str, user_query: str) -> (str, list):
189
 
190
  return markdown_report, key_frames_gallery
191
 
192
- def gradio_interface(video_url: str, user_query: str) -> (str, list):
193
  """
194
- Gradio interface function that accepts a video URL and an optional query,
195
  then returns a Markdown report and a gallery of key frame images with captions.
196
  """
197
- if not video_url:
198
- return "Please provide a valid video URL.", []
199
- return analyze_video(video_url, user_query)
200
 
201
  iface = gr.Interface(
202
  fn=gradio_interface,
203
  inputs=[
204
- gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube link or direct video file URL)"),
205
  gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
206
  ],
207
  outputs=[
@@ -211,8 +172,9 @@ iface = gr.Interface(
211
  title="AI Video Analysis and Summariser Agent",
212
  description=(
213
  "This agentic video analysis tool uses Google's Gemini 2.0 Flash model via AI Studio "
214
- "to iteratively analyze a video for security and surveillance insights. Provide a video URL and, optionally, "
215
- "a query to guide the analysis. The tool returns a detailed Markdown report along with a gallery of key frame images."
 
216
  )
217
  )
218
 
 
1
  import os
2
  import json
 
 
3
  import gradio as gr
4
  import cv2
5
  from google import genai
6
  from google.genai import types
7
  from google.genai.types import Part
8
  from tenacity import retry, stop_after_attempt, wait_random_exponential
 
9
 
10
  # Retrieve API key from environment variables.
11
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
 
19
  MODEL_NAME = "gemini-2.0-flash-001"
20
 
21
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
22
+ def call_gemini(video_file: str, prompt: str) -> str:
23
  """
24
+ Call the Gemini model with the provided video file and prompt.
25
+ The video file is read as bytes and passed with MIME type "video/mp4".
26
  """
27
+ with open(video_file, "rb") as f:
28
+ file_bytes = f.read()
29
  response = client.models.generate_content(
30
  model=MODEL_NAME,
31
  contents=[
32
+ Part(file_data=file_bytes, mime_type="video/mp4"),
33
  prompt,
34
  ],
35
  )
 
48
  else:
49
  return parts[0]
50
 
51
+ def get_key_frames(video_file: str, analysis: str, user_query: str) -> list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  """
53
  Prompt Gemini to return key frame timestamps (in HH:MM:SS) with descriptions,
54
+ then extract those frames from the uploaded video file using OpenCV.
55
 
56
  Returns a list of tuples: (image_array, caption)
57
  """
 
66
  prompt += f" Additional focus: {user_query}"
67
 
68
  try:
69
+ key_frames_response = call_gemini(video_file, prompt)
70
  # Attempt to parse the output as JSON.
71
  key_frames = json.loads(key_frames_response)
72
  if not isinstance(key_frames, list):
 
75
  key_frames = []
76
 
77
  extracted_frames = []
78
+ cap = cv2.VideoCapture(video_file)
79
+ if not cap.isOpened():
80
+ print("Error: Could not open the uploaded video file.")
81
+ return extracted_frames
82
+
83
+ for frame_obj in key_frames:
84
+ ts = frame_obj.get("timestamp")
85
+ description = frame_obj.get("description", "")
86
+ try:
87
+ seconds = hhmmss_to_seconds(ts)
88
+ except Exception:
89
+ continue
90
+ # Set video position (in milliseconds)
91
+ cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
92
+ ret, frame = cap.read()
93
+ if ret:
94
+ # Convert BGR to RGB for proper display
95
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
96
+ caption = f"{ts}: {description}"
97
+ extracted_frames.append((frame_rgb, caption))
98
+ cap.release()
 
 
 
 
 
 
99
  return extracted_frames
100
 
101
+ def analyze_video(video_file: str, user_query: str) -> (str, list):
102
  """
103
+ Perform iterative, agentic video analysis on the uploaded file.
104
  First, refine the video analysis over several iterations.
105
  Then, prompt the model to identify key frames.
106
 
 
131
  prompt += f" Remember to focus on: {user_query}"
132
 
133
  try:
134
+ analysis = call_gemini(video_file, prompt)
135
  except Exception as e:
136
  analysis += f"\n[Error during iteration {i+1}: {e}]"
137
  break
 
140
  markdown_report = f"## Video Analysis Report\n\n**Summary:**\n\n{analysis}\n"
141
 
142
  # Get key frames based on the analysis and optional query.
143
+ key_frames_gallery = get_key_frames(video_file, analysis, user_query)
144
  if not key_frames_gallery:
145
  markdown_report += "\n*No key frames were extracted.*\n"
146
  else:
 
150
 
151
  return markdown_report, key_frames_gallery
152
 
153
+ def gradio_interface(video_file, user_query: str) -> (str, list):
154
  """
155
+ Gradio interface function that accepts an uploaded video file and an optional query,
156
  then returns a Markdown report and a gallery of key frame images with captions.
157
  """
158
+ if not video_file:
159
+ return "Please upload a valid video file.", []
160
+ return analyze_video(video_file, user_query)
161
 
162
  iface = gr.Interface(
163
  fn=gradio_interface,
164
  inputs=[
165
+ gr.Video(label="Upload Video File", source="upload", type="filepath"),
166
  gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
167
  ],
168
  outputs=[
 
172
  title="AI Video Analysis and Summariser Agent",
173
  description=(
174
  "This agentic video analysis tool uses Google's Gemini 2.0 Flash model via AI Studio "
175
+ "to iteratively analyze an uploaded video for security and surveillance insights. "
176
+ "Provide a video file and, optionally, a query to guide the analysis. The tool returns a detailed "
177
+ "Markdown report along with a gallery of key frame images."
178
  )
179
  )
180