codelion commited on
Commit
830c9fb
·
verified ·
1 Parent(s): 3f2c22a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -25
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import os
2
  import json
 
 
3
  import gradio as gr
4
  import cv2
 
5
  from google import genai
6
  from google.genai import types
7
  from google.genai.types import Part
@@ -16,7 +19,7 @@ if not GOOGLE_API_KEY:
16
  client = genai.Client(api_key=GOOGLE_API_KEY)
17
 
18
  # Use the Gemini 2.0 Flash model.
19
- MODEL_NAME = "gemini-2.0-flash"
20
 
21
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
22
  def call_gemini(video_url: str, prompt: str) -> str:
@@ -46,10 +49,40 @@ def hhmmss_to_seconds(time_str: str) -> float:
46
  else:
47
  return parts[0]
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
50
  """
51
  Prompt Gemini to return key frame timestamps (in HH:MM:SS) with descriptions,
52
- then extract those frames from the video using OpenCV.
53
 
54
  Returns a list of tuples: (image_array, caption)
55
  """
@@ -73,27 +106,33 @@ def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
73
  key_frames = []
74
 
75
  extracted_frames = []
76
- cap = cv2.VideoCapture(video_url)
77
- if not cap.isOpened():
78
- print("Error: Could not open video.")
79
- return extracted_frames
80
-
81
- for frame_obj in key_frames:
82
- ts = frame_obj.get("timestamp")
83
- description = frame_obj.get("description", "")
84
- try:
85
- seconds = hhmmss_to_seconds(ts)
86
- except Exception:
87
- continue
88
- # Set video position (in milliseconds)
89
- cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
90
- ret, frame = cap.read()
91
- if ret:
92
- # Convert BGR to RGB
93
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
94
- caption = f"{ts}: {description}"
95
- extracted_frames.append((frame_rgb, caption))
96
- cap.release()
 
 
 
 
 
 
97
  return extracted_frames
98
 
99
  def analyze_video(video_url: str, user_query: str) -> (str, list):
@@ -157,11 +196,10 @@ def gradio_interface(video_url: str, user_query: str) -> (str, list):
157
  return "Please provide a valid video URL.", []
158
  return analyze_video(video_url, user_query)
159
 
160
- # Define the Gradio interface with two inputs and two outputs.
161
  iface = gr.Interface(
162
  fn=gradio_interface,
163
  inputs=[
164
- gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube direct link or video file URL)"),
165
  gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
166
  ],
167
  outputs=[
 
1
  import os
2
  import json
3
+ import tempfile
4
+ import requests
5
  import gradio as gr
6
  import cv2
7
+ from pytube import YouTube
8
  from google import genai
9
  from google.genai import types
10
  from google.genai.types import Part
 
19
  client = genai.Client(api_key=GOOGLE_API_KEY)
20
 
21
  # Use the Gemini 2.0 Flash model.
22
+ MODEL_NAME = "gemini-2.0-flash-001"
23
 
24
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
25
  def call_gemini(video_url: str, prompt: str) -> str:
 
49
  else:
50
  return parts[0]
51
 
52
+ def download_video(video_url: str) -> str:
53
+ """
54
+ Download the video from a URL (either YouTube or direct link) and return the local file path.
55
+ """
56
+ local_file = None
57
+ if "youtube.com" in video_url or "youtu.be" in video_url:
58
+ yt = YouTube(video_url)
59
+ stream = yt.streams.filter(file_extension="mp4", progressive=True).first()
60
+ if stream is None:
61
+ raise ValueError("No suitable mp4 stream found on YouTube.")
62
+ temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
63
+ stream.stream_to_buffer(temp_file)
64
+ temp_file.flush()
65
+ local_file = temp_file.name
66
+ temp_file.close()
67
+ else:
68
+ # Assume it's a direct link to a video file, download using requests.
69
+ response = requests.get(video_url, stream=True)
70
+ if response.status_code == 200:
71
+ temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
72
+ for chunk in response.iter_content(chunk_size=8192):
73
+ if chunk:
74
+ temp_file.write(chunk)
75
+ temp_file.flush()
76
+ local_file = temp_file.name
77
+ temp_file.close()
78
+ else:
79
+ raise ValueError("Failed to download video, status code: " + str(response.status_code))
80
+ return local_file
81
+
82
  def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
83
  """
84
  Prompt Gemini to return key frame timestamps (in HH:MM:SS) with descriptions,
85
+ then extract those frames from the downloaded video file using OpenCV.
86
 
87
  Returns a list of tuples: (image_array, caption)
88
  """
 
106
  key_frames = []
107
 
108
  extracted_frames = []
109
+ local_path = None
110
+ try:
111
+ local_path = download_video(video_url)
112
+ cap = cv2.VideoCapture(local_path)
113
+ if not cap.isOpened():
114
+ print("Error: Could not open video from local file.")
115
+ return extracted_frames
116
+
117
+ for frame_obj in key_frames:
118
+ ts = frame_obj.get("timestamp")
119
+ description = frame_obj.get("description", "")
120
+ try:
121
+ seconds = hhmmss_to_seconds(ts)
122
+ except Exception:
123
+ continue
124
+ # Set video position (in milliseconds)
125
+ cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
126
+ ret, frame = cap.read()
127
+ if ret:
128
+ # Convert BGR to RGB
129
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
130
+ caption = f"{ts}: {description}"
131
+ extracted_frames.append((frame_rgb, caption))
132
+ cap.release()
133
+ finally:
134
+ if local_path and os.path.exists(local_path):
135
+ os.remove(local_path)
136
  return extracted_frames
137
 
138
  def analyze_video(video_url: str, user_query: str) -> (str, list):
 
196
  return "Please provide a valid video URL.", []
197
  return analyze_video(video_url, user_query)
198
 
 
199
  iface = gr.Interface(
200
  fn=gradio_interface,
201
  inputs=[
202
+ gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube link or direct video file URL)"),
203
  gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
204
  ],
205
  outputs=[