# AI Video Analysis and Summariser Agent — Gradio app backed by Gemini 2.0 Flash.
import os | |
import json | |
import gradio as gr | |
import cv2 | |
from google import genai | |
from google.genai.types import Part | |
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
# Retrieve API key from environment variables.
# Failing fast here (at import time) gives a clear error instead of a
# confusing authentication failure on the first model call.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
# Initialize the Gemini API client via AI Studio.
client = genai.Client(api_key=GOOGLE_API_KEY)
# Use the Gemini 2.0 Flash model.
MODEL_NAME = "gemini-2.0-flash-001"
def call_gemini(video_file: str, prompt: str) -> str:
    """
    Call the Gemini model with the provided video file and prompt.

    The video file is read as raw bytes and sent inline with MIME type
    "video/mp4", followed by the text prompt.

    Args:
        video_file: Path to a local MP4 file.
        prompt: Instruction text guiding the model.

    Returns:
        The model's text response.

    Raises:
        OSError: If the video file cannot be read.
        Exception: Propagates any error raised by the Gemini SDK call.
    """
    with open(video_file, "rb") as f:
        file_bytes = f.read()
    # BUG FIX: Part(file_data=..., mime_type=...) is invalid in google-genai —
    # `file_data` expects a FileData (file URI) object and Part has no
    # top-level mime_type field. Inline raw bytes go through Part.from_bytes.
    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[
            Part.from_bytes(data=file_bytes, mime_type="video/mp4"),
            Part(text=prompt),
        ],
    )
    return response.text
def safe_call_gemini(video_file: str, prompt: str) -> str:
    """
    Best-effort wrapper around call_gemini.

    Returns the model's response text, or the fallback string
    "No summary available." if the underlying call raises any exception
    (the error is printed for diagnostics rather than propagated).
    """
    try:
        result = call_gemini(video_file, prompt)
    except Exception as exc:
        print("Gemini call failed:", exc)
        return "No summary available."
    return result
def hhmmss_to_seconds(time_str: str) -> float:
    """
    Convert a colon-separated timestamp string into seconds.

    Accepts "HH:MM:SS", "MM:SS", or a bare seconds value; components may be
    floats. Generalized to any number of components, each weighted 60x the
    component to its right (so "1:00:00:00" is one day). For 1-3 components
    this matches the original HH:MM:SS behavior exactly.

    Args:
        time_str: Timestamp such as "00:01:30", "2:30", or "90".

    Returns:
        Total seconds as a float.

    Raises:
        ValueError: If any component is not a valid number.
    """
    components = [float(part) for part in time_str.strip().split(":")]
    # Positional base-60 weighting: rightmost component is seconds (60**0),
    # next is minutes (60**1), then hours (60**2), and so on.
    return sum(value * 60 ** place for place, value in enumerate(reversed(components)))
def get_key_frames(video_file: str, summary: str, user_query: str) -> list:
    """
    Extract key frames from the video, guided by Gemini.

    Gemini is prompted to emit one "HH:MM:SS - description" line per event.
    Each parsed timestamp is then seeked in the video with OpenCV and the
    frame at that position is captured.

    Args:
        video_file: Path to the uploaded video.
        summary: Previously generated summary, included in the prompt.
        user_query: Optional focus hint appended to the prompt.

    Returns:
        A list of (rgb_image_array, caption) tuples; empty if the video
        cannot be opened or no timestamp lines parse.
    """
    prompt = (
        "List the key timestamps in the video and a brief description of the event at that time. "
        "Output one line per event in the following format: HH:MM:SS - description. Do not include any extra text."
    )
    prompt += f" Video Summary: {summary}"
    if user_query:
        prompt += f" Focus on: {user_query}"
    # The fallback text from safe_call_gemini contains no " - " lines,
    # so a failed model call simply yields zero events below.
    response_text = safe_call_gemini(video_file, prompt)

    events = []
    for raw_line in response_text.strip().split("\n"):
        if " - " not in raw_line:
            continue
        stamp_part, desc_part = raw_line.split(" - ", 1)
        events.append({"timestamp": stamp_part.strip(), "description": desc_part.strip()})

    frames = []
    capture = cv2.VideoCapture(video_file)
    if not capture.isOpened():
        print("Error: Could not open the uploaded video file.")
        return frames
    for event in events:
        stamp = event.get("timestamp")
        note = event.get("description", "")
        try:
            position_s = hhmmss_to_seconds(stamp)
        except Exception:
            # Model emitted an unparseable timestamp; skip this event.
            continue
        capture.set(cv2.CAP_PROP_POS_MSEC, position_s * 1000)
        ok, bgr_frame = capture.read()
        if ok:
            rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
            frames.append((rgb_frame, f"{stamp}: {note}"))
    capture.release()
    return frames
def analyze_video(video_file: str, user_query: str) -> tuple[str, list]:
    """
    Perform a two-stage analysis of the uploaded video.

    First calls Gemini for a brief summary, then asks it to list key
    timestamps and extracts the matching frames.

    Args:
        video_file: Path to the uploaded video.
        user_query: Optional focus hint appended to both prompts.

    Returns:
        A tuple of:
        - A Markdown report summarizing the video and listing key frames.
        - A gallery list of (image, caption) tuples for the key frames.
    """
    # NOTE: the annotation was `(str, list)` — a tuple literal, not a type;
    # fixed to `tuple[str, list]` (valid on Python 3.9+).
    summary_prompt = "Summarize this video."
    if user_query:
        summary_prompt += f" Also focus on: {user_query}"
    summary = safe_call_gemini(video_file, summary_prompt)
    markdown_report = f"## Video Analysis Report\n\n**Summary:**\n\n{summary}\n"
    key_frames_gallery = get_key_frames(video_file, summary, user_query)
    if not key_frames_gallery:
        markdown_report += "\n*No key frames were extracted.*\n"
    else:
        markdown_report += "\n**Key Frames Extracted:**\n"
        for idx, (img, caption) in enumerate(key_frames_gallery, start=1):
            markdown_report += f"- **Frame {idx}:** {caption}\n"
    return markdown_report, key_frames_gallery
def gradio_interface(video_file, user_query: str) -> tuple[str, list]:
    """
    Gradio entry point: accept an uploaded video and optional query.

    Args:
        video_file: File path supplied by the gr.Video component (falsy when
            nothing was uploaded).
        user_query: Optional text to guide the analysis focus.

    Returns:
        A tuple of (Markdown report, gallery list of key frames). The
        annotation was fixed from the tuple literal `(str, list)` to the
        proper type `tuple[str, list]`.
    """
    if not video_file:
        # Guard clause: nothing to analyze, return an empty gallery.
        return "Please upload a valid video file.", []
    return analyze_video(video_file, user_query)
# Wire the analysis function into a simple Gradio UI: one video upload,
# one optional text query; outputs a Markdown report plus an image gallery.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Video(label="Upload Video File"),
        gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
    ],
    outputs=[
        gr.Markdown(label="Security & Surveillance Analysis Report"),
        # Two-column gallery of (image, caption) tuples from get_key_frames.
        gr.Gallery(label="Extracted Key Frames", columns=2)
    ],
    title="AI Video Analysis and Summariser Agent",
    description=(
        "This tool uses Google's Gemini 2.0 Flash model via AI Studio to analyze an uploaded video. "
        "It returns a brief summary and extracts key frames based on that summary. "
        "Provide a video file and, optionally, a query to guide the analysis."
    )
)

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()