# Spaces: Sleeping  (status banner captured from the hosting page; not part of the program)
import os | |
import json | |
import tempfile | |
import requests | |
import gradio as gr | |
import cv2 | |
from pytube import YouTube | |
from google import genai | |
from google.genai import types | |
from google.genai.types import Part | |
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
# Gemini 2.0 Flash model identifier used for generate_content requests.
MODEL_NAME = "gemini-2.0-flash-001"

# The API key comes from the environment; refuse to start without it.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY is None or GOOGLE_API_KEY == "":
    raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

# Gemini API client (AI Studio), authenticated with the key read above.
client = genai.Client(api_key=GOOGLE_API_KEY)
def call_gemini(video_url: str, prompt: str) -> str:
    """Send a video reference and a text prompt to Gemini and return the reply text.

    The video is handed to the model as a URI part with MIME type
    "video/webm", followed by the prompt text.

    Args:
        video_url: URI of the video to analyse.
        prompt: Text instruction sent alongside the video.

    Returns:
        The text of the model's response, or an empty string when the
        response carries no text.
    """
    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[
            Part.from_uri(file_uri=video_url, mime_type="video/webm"),
            prompt,
        ],
    )
    # ``response.text`` can be None when the reply has no text part; return ""
    # so the declared ``str`` return type holds and callers can safely
    # concatenate or parse the result.
    return response.text or ""
def hhmmss_to_seconds(time_str: str) -> float:
    """Convert a colon-separated timestamp into a number of seconds.

    Accepted forms are "HH:MM:SS", "MM:SS" and "SS"; surrounding whitespace
    is ignored and each field may be fractional (e.g. "00:00:01.5").

    Args:
        time_str: The timestamp text to convert.

    Returns:
        The timestamp expressed in seconds.

    Raises:
        ValueError: If a field is not numeric, or if there are more than
            three colon-separated fields.
    """
    fields = [float(field) for field in time_str.strip().split(":")]
    if len(fields) > 3:
        # Previously such input silently returned its first field as seconds,
        # which would seek to an unrelated position in the video.
        raise ValueError(f"Expected at most HH:MM:SS, got {time_str!r}")
    # Left-pad with zeros so the fields always read hours, minutes, seconds.
    hours, minutes, seconds = [0.0] * (3 - len(fields)) + fields
    return hours * 3600 + minutes * 60 + seconds
def _save_to_temp_mp4(write_into) -> str:
    """Create a temporary ``.mp4`` file, let *write_into* fill it, and return its path.

    *write_into* is called with the open binary file object. If it raises,
    the partially written file is removed before the exception propagates.
    """
    temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    completed = False
    try:
        # Leaving the ``with`` block flushes and closes the file; because of
        # ``delete=False`` the file itself stays on disk for the caller.
        with temp_file:
            write_into(temp_file)
        completed = True
    finally:
        if not completed and os.path.exists(temp_file.name):
            os.remove(temp_file.name)
    return temp_file.name


def download_video(video_url: str, timeout: float = 30.0) -> str:
    """Download a video to a temporary ``.mp4`` file and return its local path.

    URLs containing "youtube.com" or "youtu.be" are fetched through pytube
    (first progressive mp4 stream); any other URL is treated as a direct link
    and streamed with ``requests``. The caller is responsible for deleting
    the returned file.

    Args:
        video_url: YouTube URL or direct link to a video file.
        timeout: Seconds passed to ``requests.get`` as its timeout for direct
            downloads, so a stalled server cannot block the call forever.

    Returns:
        Path of the downloaded temporary file.

    Raises:
        ValueError: If no suitable YouTube stream exists, or a direct
            download answers with a status code other than 200.
    """
    if "youtube.com" in video_url or "youtu.be" in video_url:
        stream = (
            YouTube(video_url)
            .streams.filter(file_extension="mp4", progressive=True)
            .first()
        )
        if stream is None:
            raise ValueError("No suitable mp4 stream found on YouTube.")
        return _save_to_temp_mp4(stream.stream_to_buffer)

    # Assume a direct link to a video file. The ``with`` block releases the
    # connection even when the status check or the copy fails.
    with requests.get(video_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise ValueError(
                "Failed to download video, status code: " + str(response.status_code)
            )

        def write_body(out_file) -> None:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    out_file.write(chunk)

        return _save_to_temp_mp4(write_body)
def _parse_key_frames(raw_response: str) -> list:
    """Turn the model's key-frame reply into a list of dicts; return [] if unusable."""
    text = (raw_response or "").strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # The reply may wrap the JSON array in a Markdown code fence or in
        # surrounding prose; retry on the outermost [...] span.
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end <= start:
            return []
        try:
            parsed = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return []
    if not isinstance(parsed, list):
        return []
    # Keep only JSON objects; other element types cannot carry the
    # 'timestamp' / 'description' keys read by the caller.
    return [item for item in parsed if isinstance(item, dict)]


def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
    """Ask Gemini for key-frame timestamps and extract those frames with OpenCV.

    The model is prompted for a JSON array of objects with 'timestamp'
    (HH:MM:SS) and 'description' keys. When at least one usable entry comes
    back, the video is downloaded to a temporary file, each timestamp is
    sought and the frame read, and the temporary file is removed afterwards.

    Args:
        video_url: URL of the video (passed to the model and the downloader).
        analysis: Prior analysis text embedded in the prompt.
        user_query: Optional extra focus appended to the prompt when non-empty.

    Returns:
        A list of ``(image_array, caption)`` tuples, where the image is the
        frame converted to RGB and the caption is "<timestamp>: <description>".
        Empty when the model yields no usable key frames or the video cannot
        be opened.
    """
    prompt = (
        "Based on the following video analysis, identify key frames that best illustrate "
        "the important events or anomalies. Return a JSON array where each element is an object "
        "with two keys: 'timestamp' (in HH:MM:SS format) and 'description' (a brief explanation of why "
        "this frame is important)."
    )
    prompt += f" Video Analysis: {analysis}"
    if user_query:
        prompt += f" Additional focus: {user_query}"

    try:
        key_frames = _parse_key_frames(call_gemini(video_url, prompt))
    except Exception as e:
        # Key frames are best-effort: report the failure and continue with none.
        print(f"Error: Could not obtain key frames from the model: {e}")
        key_frames = []

    extracted_frames = []
    if not key_frames:
        # Nothing to extract, so skip downloading the video altogether.
        return extracted_frames

    local_path = None
    cap = None
    try:
        local_path = download_video(video_url)
        cap = cv2.VideoCapture(local_path)
        if not cap.isOpened():
            print("Error: Could not open video from local file.")
            return extracted_frames
        for frame_obj in key_frames:
            ts = frame_obj.get("timestamp")
            description = frame_obj.get("description", "")
            try:
                seconds = hhmmss_to_seconds(ts)
            except (AttributeError, TypeError, ValueError):
                # Missing or malformed timestamp: skip this entry.
                continue
            # Set video position (in milliseconds)
            cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
            ret, frame = cap.read()
            if ret:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                extracted_frames.append((frame_rgb, f"{ts}: {description}"))
    finally:
        # Release the capture before deleting the file it has open, and do
        # both even when frame extraction raises.
        if cap is not None:
            cap.release()
        if local_path and os.path.exists(local_path):
            os.remove(local_path)
    return extracted_frames
def analyze_video(video_url: str, user_query: str, num_iterations: int = 3) -> tuple[str, list]:
    """Run an iterative video analysis and build a report with key frames.

    The first iteration asks the model for a security-focused summary of the
    video; each later iteration feeds the previous analysis back and asks for
    refinement. The final analysis is then used to request key frames.

    Args:
        video_url: URL of the video to analyse.
        user_query: Optional focus for the analysis; ignored when empty.
        num_iterations: Number of analysis passes to run (default 3).

    Returns:
        A tuple ``(markdown_report, key_frames_gallery)`` where the gallery is
        a list of ``(image, caption)`` tuples.
    """
    # The opening prompt does not depend on the iteration, so build it once.
    base_prompt = (
        "You are a video analysis agent focusing on security and surveillance. "
        "Provide a detailed summary of the video, highlighting key events, suspicious activities, or anomalies."
    )
    if user_query:
        base_prompt += f" Also, focus on the following query: {user_query}"

    analysis = ""
    for i in range(num_iterations):
        if i == 0:
            prompt = base_prompt
        else:
            prompt = (
                f"Based on the previous analysis: \"{analysis}\". "
                "Provide further elaboration and refined insights, focusing on potential security threats, anomalous events, "
                "and details that would help a security team understand the situation better."
            )
            if user_query:
                prompt += f" Remember to focus on: {user_query}"
        try:
            refined = call_gemini(video_url, prompt)
        except Exception as e:
            # Keep what has been gathered so far and record why refinement stopped.
            analysis += f"\n[Error during iteration {i+1}: {e}]"
            break
        # An empty (or missing) reply must not wipe out an earlier, useful analysis.
        if refined:
            analysis = refined

    # Create a Markdown report
    report_parts = [f"## Video Analysis Report\n\n**Summary:**\n\n{analysis}\n"]

    # Get key frames based on the analysis and optional query.
    key_frames_gallery = get_key_frames(video_url, analysis, user_query)
    if not key_frames_gallery:
        report_parts.append("\n*No key frames were extracted.*\n")
    else:
        report_parts.append("\n**Key Frames Extracted:**\n")
        report_parts.extend(
            f"- **Frame {idx}:** {caption}\n"
            for idx, (_, caption) in enumerate(key_frames_gallery, start=1)
        )
    return "".join(report_parts), key_frames_gallery
def gradio_interface(video_url: str, user_query: str) -> tuple[str, list]:
    """Validate the URL typed into the UI and run the video analysis.

    Args:
        video_url: Video URL entered by the user; surrounding whitespace is
            ignored.
        user_query: Optional text guiding the focus of the analysis.

    Returns:
        A tuple ``(markdown_report, gallery)``. When no URL is given (empty,
        whitespace-only or ``None``), the report is a prompt to provide one
        and the gallery is empty.
    """
    # Pasted URLs often carry stray spaces or newlines; a whitespace-only
    # value is treated the same as a missing one.
    cleaned_url = video_url.strip() if video_url else ""
    if not cleaned_url:
        return "Please provide a valid video URL.", []
    return analyze_video(cleaned_url, user_query)
# Input widgets: the video location and an optional free-text focus.
video_url_box = gr.Textbox(
    label="Video URL (publicly accessible, e.g., YouTube link or direct video file URL)"
)
query_box = gr.Textbox(
    label="Analysis Query (optional): guide the focus of the analysis",
    placeholder="e.g., focus on unusual movements near the entrance",
)

# Output widgets: the Markdown report and a two-column gallery of key frames.
report_view = gr.Markdown(label="Security & Surveillance Analysis Report")
frames_gallery = gr.Gallery(label="Extracted Key Frames", columns=2)

# Blurb shown under the title of the web page.
app_description = (
    "This agentic video analysis tool uses Google's Gemini 2.0 Flash model via AI Studio "
    "to iteratively analyze a video for security and surveillance insights. Provide a video URL and, optionally, "
    "a query to guide the analysis. The tool returns a detailed Markdown report along with a gallery of key frame images."
)

# Wire the handler function to the widgets defined above.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[video_url_box, query_box],
    outputs=[report_view, frames_gallery],
    title="AI Video Analysis and Summariser Agent",
    description=app_description,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()