# videoanalysis / app.py
# Author: codelion -- commit d638712 (verified): "Update app.py" -- 6.13 kB
import os
import json
import gradio as gr
import cv2
from google import genai
from google.genai.types import Part
from tenacity import retry, stop_after_attempt, wait_random_exponential
# Gemini 2.0 Flash model identifier used for every request.
MODEL_NAME = "gemini-2.0-flash-001"

# The API key comes from the environment; refuse to start without it.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

# Gemini API client (AI Studio), authenticated with the key above.
client = genai.Client(api_key=GOOGLE_API_KEY)
@retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
def call_gemini(video_file: str, prompt: str) -> str:
    """
    Send a video file and a text prompt to the Gemini model and return the reply text.

    The whole file is read into memory and sent inline with MIME type
    "video/mp4"; the prompt is sent as a text part. Failed attempts are
    retried by the tenacity decorator: at most 3 attempts, with randomized
    exponential back-off capped at 60 seconds.

    NOTE(review): inline uploads are presumably subject to a request size
    limit on the API side; large videos may need a file-upload flow instead.
    Confirm against the Gemini API documentation.

    Args:
        video_file: Path of the video file on disk.
        prompt: Instruction text sent alongside the video.

    Returns:
        The text of the model response, or an empty string when the response
        carries no text.

    Raises:
        tenacity.RetryError: Once all attempts have failed (tenacity's
            default behavior when ``reraise`` is not set).
    """
    with open(video_file, "rb") as f:
        file_bytes = f.read()
    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[
            # Raw bytes must be wrapped as inline data via from_bytes():
            # Part's ``file_data`` field expects a FileData reference (a URI),
            # and Part itself has no ``mime_type`` field.
            Part.from_bytes(data=file_bytes, mime_type="video/mp4"),
            Part(text=prompt),
        ],
    )
    # ``response.text`` can be None when the reply contains no text parts;
    # normalise so the declared ``str`` return type always holds.
    return response.text or ""
def safe_call_gemini(video_file: str, prompt: str) -> str:
    """
    Best-effort variant of call_gemini.

    Any exception raised by the underlying call is printed and replaced by
    the placeholder text "No summary available.", so the caller always gets
    a value back instead of an error.
    """
    try:
        reply = call_gemini(video_file, prompt)
    except Exception as err:
        print("Gemini call failed:", err)
        reply = "No summary available."
    return reply
def hhmmss_to_seconds(time_str: str) -> float:
    """
    Convert a timestamp string into a number of seconds.

    Accepted layouts are "HH:MM:SS", "MM:SS" and "SS". Each field is parsed
    with ``float`` so fractional values such as "00:01:30.5" are allowed, and
    surrounding whitespace is ignored.

    Args:
        time_str: Timestamp text to convert.

    Returns:
        The timestamp expressed in seconds.

    Raises:
        ValueError: If a field is not a number, or if the string has more
            than three colon-separated fields.
    """
    fields = [float(field) for field in time_str.strip().split(":")]
    if len(fields) > 3:
        # A fourth field has no defined meaning here; refuse it rather than
        # guess and return a misleading position.
        raise ValueError(f"Unsupported timestamp format: {time_str!r}")
    if len(fields) == 3:
        hours, minutes, seconds = fields
        return hours * 3600 + minutes * 60 + seconds
    if len(fields) == 2:
        minutes, seconds = fields
        return minutes * 60 + seconds
    # str.split always yields at least one element, so this is the "SS" case.
    return fields[0]
def _parse_key_frame_lines(response_text: str) -> list:
    """Return (timestamp, description) pairs from "timestamp - description" lines."""
    pairs = []
    # ``or ""`` tolerates a missing (None) response instead of crashing.
    for line in (response_text or "").splitlines():
        # Split on the first " - " only, so descriptions may contain dashes.
        timestamp, separator, description = line.partition(" - ")
        if separator:
            pairs.append((timestamp.strip(), description.strip()))
    return pairs


def get_key_frames(video_file: str, summary: str, user_query: str) -> list:
    """
    Ask Gemini for key timestamps and extract the matching frames with OpenCV.

    The prompt asks the model for one line per event in the format
    "HH:MM:SS - description". Lines without the " - " separator, timestamps
    that cannot be parsed, and positions where no frame can be read are
    skipped silently.

    Args:
        video_file: Path of the video file on disk.
        summary: Previously generated summary, appended to the prompt.
        user_query: Optional focus text; appended to the prompt when non-empty.

    Returns:
        A list of ``(image_array, caption)`` tuples, where the image is the
        frame converted from BGR to RGB and the caption is
        "<timestamp>: <description>". The list is empty when the model
        returned no usable lines or the video could not be opened.
    """
    prompt = (
        "List the key timestamps in the video and a brief description of the event at that time. "
        "Output one line per event in the following format: HH:MM:SS - description. Do not include any extra text."
    )
    prompt += f" Video Summary: {summary}"
    if user_query:
        prompt += f" Focus on: {user_query}"

    # The safe call returns fallback text on failure; that text has no
    # " - " separator, so it simply produces no key frames.
    key_frames = _parse_key_frame_lines(safe_call_gemini(video_file, prompt))

    extracted_frames = []
    if not key_frames:
        # Nothing to extract, so do not open the video at all.
        return extracted_frames

    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            print("Error: Could not open the uploaded video file.")
            return extracted_frames
        for ts, description in key_frames:
            try:
                seconds = hhmmss_to_seconds(ts)
            except ValueError:
                # Not a parseable timestamp; skip this line.
                continue
            # Seek by time (milliseconds), then read the frame at that position.
            cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
            ret, frame = cap.read()
            if ret:
                # OpenCV yields BGR; convert to RGB for display.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                extracted_frames.append((frame_rgb, f"{ts}: {description}"))
    finally:
        # Release the capture even if frame conversion raises.
        cap.release()
    return extracted_frames
def analyze_video(video_file: str, user_query: str) -> (str, list):
    """
    Run the full analysis for one uploaded video.

    Gemini is first asked for a summary (optionally steered by the user
    query); the summary is then handed to get_key_frames to obtain the key
    frames.

    Returns:
        - The Markdown report: the summary followed by either a bullet list
          of frame captions or a note that no key frames were extracted.
        - The gallery list of key frames, each a tuple of (image, caption).
    """
    focus_clause = f" Also focus on: {user_query}" if user_query else ""
    summary = safe_call_gemini(video_file, "Summarize this video." + focus_clause)
    gallery = get_key_frames(video_file, summary, user_query)

    # Assemble the report from pieces and join once at the end.
    report_parts = [f"## Video Analysis Report\n\n**Summary:**\n\n{summary}\n"]
    if gallery:
        report_parts.append("\n**Key Frames Extracted:**\n")
        report_parts.extend(
            f"- **Frame {number}:** {caption}\n"
            for number, (_, caption) in enumerate(gallery, start=1)
        )
    else:
        report_parts.append("\n*No key frames were extracted.*\n")
    return "".join(report_parts), gallery
def gradio_interface(video_file, user_query: str) -> (str, list):
    """
    Entry point wired into the Gradio interface.

    Takes the uploaded video and the optional query. When a video is present
    the result of analyze_video is returned; otherwise a prompt to upload a
    file is returned together with an empty gallery.
    """
    if video_file:
        return analyze_video(video_file, user_query)
    return "Please upload a valid video file.", []
# Gradio UI definition: a video upload and an optional query go in; a
# Markdown report and a two-column gallery of key frames come out.
iface = gr.Interface(
    fn=gradio_interface,
    title="AI Video Analysis and Summariser Agent",
    description=(
        "This tool uses Google's Gemini 2.0 Flash model via AI Studio to analyze an uploaded video. "
        "It returns a brief summary and extracts key frames based on that summary. "
        "Provide a video file and, optionally, a query to guide the analysis."
    ),
    inputs=[
        gr.Video(label="Upload Video File"),
        gr.Textbox(
            label="Analysis Query (optional): guide the focus of the analysis",
            placeholder="e.g., focus on unusual movements near the entrance",
        ),
    ],
    outputs=[
        gr.Markdown(label="Security & Surveillance Analysis Report"),
        gr.Gallery(label="Extracted Key Frames", columns=2),
    ],
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()