File size: 8,411 Bytes
f8aaa9d
7c2c622
830c9fb
 
f8aaa9d
7c2c622
830c9fb
f8aaa9d
 
 
 
 
0f96bc2
f8aaa9d
 
 
 
0f96bc2
f8aaa9d
 
0f96bc2
830c9fb
f8aaa9d
 
 
 
 
7c2c622
f8aaa9d
 
 
 
 
 
 
 
 
 
7c2c622
f8aaa9d
7c2c622
0f96bc2
7c2c622
 
 
 
 
 
 
 
 
830c9fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2c622
 
 
830c9fb
7c2c622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f96bc2
7c2c622
830c9fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2c622
 
 
0f96bc2
7c2c622
 
 
 
 
 
 
f8aaa9d
 
 
 
 
7c2c622
 
 
 
0f96bc2
 
 
f8aaa9d
0f96bc2
f8aaa9d
7c2c622
 
 
 
 
0f96bc2
7c2c622
0f96bc2
f8aaa9d
 
 
 
0f96bc2
 
7c2c622
0f96bc2
f8aaa9d
7c2c622
 
 
 
 
 
 
 
 
 
 
 
f8aaa9d
7c2c622
 
f8aaa9d
 
7c2c622
0f96bc2
f8aaa9d
 
 
0f96bc2
830c9fb
0f96bc2
 
 
 
3f2c22a
0f96bc2
f8aaa9d
 
 
0f96bc2
7c2c622
f8aaa9d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import json
import tempfile
import requests
import gradio as gr
import cv2
from pytube import YouTube
from google import genai
from google.genai import types
from google.genai.types import Part
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Gemini credentials are read from the environment; refuse to load the module
# without them (an unset or empty variable both count as missing).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY is None or GOOGLE_API_KEY == "":
    raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

# Gemini model identifier used for every request (Gemini 2.0 Flash).
MODEL_NAME = "gemini-2.0-flash-001"

# Single API-key-authenticated Gemini client shared by all calls in this module.
client = genai.Client(api_key=GOOGLE_API_KEY)

@retry(
    wait=wait_random_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(3),
    # Surface the underlying API error after the last attempt instead of a
    # tenacity.RetryError wrapper, so callers' error messages are meaningful.
    reraise=True,
)
def call_gemini(video_url: str, prompt: str, mime_type: str = "video/webm") -> str:
    """
    Send a video reference plus a text prompt to the Gemini model and return its reply.

    The call is retried up to three times with randomized exponential back-off.

    Args:
        video_url: URI of the video, passed to Gemini as a file-URI part.
        prompt: Instruction text sent alongside the video.
        mime_type: MIME type declared for the video part. Defaults to
            "video/webm", the value this module has always used.

    Returns:
        The model's text response, or "" when the response carries no text.

    Raises:
        The exception raised by the final failed attempt.
    """
    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[
            Part.from_uri(file_uri=video_url, mime_type=mime_type),
            prompt,
        ],
    )
    # `response.text` may be None (e.g. no text candidates); callers concatenate
    # and JSON-parse the result, so normalise to a string.
    return response.text or ""

def hhmmss_to_seconds(time_str: str) -> float:
    """
    Convert an ``HH:MM:SS``, ``MM:SS`` or ``SS`` timestamp into seconds.

    Each field may be fractional (e.g. ``"00:00:01.5"``). Surrounding whitespace is
    ignored. A plain int/float is accepted as an already-computed number of seconds,
    because JSON replies may encode timestamps numerically.

    Raises:
        TypeError: If ``time_str`` is neither a string nor a number.
        ValueError: If a field is not numeric or there are more than three fields.
    """
    if isinstance(time_str, (int, float)):
        return float(time_str)
    if not isinstance(time_str, str):
        raise TypeError(f"Unsupported timestamp type: {type(time_str).__name__}")

    fields = [float(part) for part in time_str.strip().split(":")]
    if len(fields) > 3:
        # Previously such input silently returned only its first field.
        raise ValueError(f"Unrecognised timestamp: {time_str!r}")

    # Fold left-to-right: each field is worth 60x the one after it.
    seconds = 0.0
    for field in fields:
        seconds = seconds * 60 + field
    return seconds

def download_video(video_url: str, timeout: float = 60.0) -> str:
    """
    Download a video (YouTube page or direct file link) to a local temporary file.

    URLs containing "youtube.com" or "youtu.be" are fetched with pytube using the
    first progressive mp4 stream. Any other URL is fetched with an HTTP GET and
    streamed to disk in 8 KiB chunks.

    Args:
        video_url: The URL to download.
        timeout: Timeout in seconds passed to ``requests.get`` for direct links,
            so a stalled server cannot block the request indefinitely.

    Returns:
        Path of a ``.mp4`` temporary file created with ``delete=False``; the
        caller is responsible for removing it.

    Raises:
        ValueError: If no suitable YouTube stream exists, or the direct download
            does not answer with HTTP 200.
    """
    if "youtube.com" in video_url or "youtu.be" in video_url:
        yt = YouTube(video_url)
        stream = yt.streams.filter(file_extension="mp4", progressive=True).first()
        if stream is None:
            raise ValueError("No suitable mp4 stream found on YouTube.")
        return _write_temp_video(stream.stream_to_buffer)

    # Assume it's a direct link to a video file. Using the streamed response as a
    # context manager releases its connection on every exit path.
    with requests.get(video_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise ValueError("Failed to download video, status code: " + str(response.status_code))

        def _copy_body(fh) -> None:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty chunks
                    fh.write(chunk)

        return _write_temp_video(_copy_body)


def _write_temp_video(write_into) -> str:
    """Create a ``.mp4`` temp file, fill it via ``write_into(file)``, and return its path.

    If ``write_into`` raises, the partially written file is deleted before the
    exception propagates, so failed downloads leave nothing behind.
    """
    fh = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    try:
        with fh:  # closing also flushes
            write_into(fh)
    except BaseException:
        os.remove(fh.name)
        raise
    return fh.name

def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
    """
    Ask Gemini for key-frame timestamps and extract those frames from the video.

    Gemini receives the video plus the earlier ``analysis`` and is asked for a JSON
    array of ``{"timestamp": "HH:MM:SS", "description": ...}`` objects. The reply is
    parsed leniently, so markdown code fences or surrounding prose are tolerated.
    The video is then downloaded to a temporary file, and OpenCV grabs one frame per
    timestamp. The temporary file is always removed afterwards.

    Args:
        video_url: URL of the video, used for both the Gemini request and the download.
        analysis: Text of the earlier analysis, embedded in the prompt.
        user_query: Optional extra focus, appended to the prompt when truthy.

    Returns:
        A list of ``(image_array, caption)`` tuples. ``image_array`` is the frame
        converted from BGR to RGB, and ``caption`` is ``"<timestamp>: <description>"``.
        The list is empty if the model gave no usable timestamps, or if the video
        could not be downloaded or opened. Those failures are printed, not raised.
    """
    prompt = (
        "Based on the following video analysis, identify key frames that best illustrate "
        "the important events or anomalies. Return a JSON array where each element is an object "
        "with two keys: 'timestamp' (in HH:MM:SS format) and 'description' (a brief explanation of why "
        "this frame is important). Respond with the JSON array only, without markdown formatting."
    )
    prompt += f" Video Analysis: {analysis}"
    if user_query:
        prompt += f" Additional focus: {user_query}"

    try:
        key_frames = _parse_key_frames_reply(call_gemini(video_url, prompt))
    except Exception as e:
        print(f"Error: key frame request failed: {e}")
        key_frames = []

    # Nothing to extract: skip the download entirely.
    if not key_frames:
        return []

    try:
        local_path = download_video(video_url)
    except Exception as e:
        print(f"Error: could not download video for key frame extraction: {e}")
        return []

    try:
        return _grab_frames(local_path, key_frames)
    finally:
        if local_path and os.path.exists(local_path):
            os.remove(local_path)


def _parse_key_frames_reply(reply) -> list:
    """Extract the list of key-frame dicts from a model reply; return [] if none is found.

    Only the span from the first '[' to the last ']' is decoded, which tolerates
    ```json fences and explanatory text around the array. Non-dict elements are dropped.
    """
    if not reply:
        return []
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        data = json.loads(reply[start:end + 1])
    except json.JSONDecodeError:
        return []
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]


def _grab_frames(path: str, key_frames: list) -> list:
    """Read one RGB frame per key-frame timestamp from the video file at ``path``.

    Entries whose timestamp cannot be converted are skipped. The capture is
    released on every exit path.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video from local file.")
            return []

        frames = []
        for frame_obj in key_frames:
            ts = frame_obj.get("timestamp")
            description = frame_obj.get("description") or ""
            try:
                seconds = hhmmss_to_seconds(ts)
            except Exception:
                continue
            # Seek by position in milliseconds, then decode the frame there.
            cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
            ok, frame = cap.read()
            if ok:
                # OpenCV decodes to BGR; Gradio expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append((frame_rgb, f"{ts}: {description}"))
        return frames
    finally:
        cap.release()

def analyze_video(video_url: str, user_query: str, num_iterations: int = 3) -> tuple[str, list]:
    """
    Run an iterative ("agentic") security-focused analysis of a video.

    The first pass asks Gemini for a detailed summary. Each later pass feeds the
    previous answer back and asks for a refined elaboration. If a call fails, the
    error is appended to the current analysis text and refinement stops.

    Key frames are then requested via ``get_key_frames``. This step is skipped when
    not a single pass succeeded, because re-asking the model about a video it could
    not analyse only adds another retried request.

    Args:
        video_url: URL of the video to analyse.
        user_query: Optional extra focus for the analysis; falsy values are ignored.
        num_iterations: Number of analysis passes (default 3). At least one pass
            always runs.

    Returns:
        A ``(markdown_report, gallery)`` tuple, where ``gallery`` is the list of
        ``(image, caption)`` tuples from ``get_key_frames`` (empty if skipped).
    """
    base_prompt = (
        "You are a video analysis agent focusing on security and surveillance. "
        "Provide a detailed summary of the video, highlighting key events, suspicious activities, or anomalies."
    )
    if user_query:
        base_prompt += f" Also, focus on the following query: {user_query}"

    analysis = ""
    any_success = False
    for i in range(max(1, num_iterations)):
        if i == 0:
            prompt = base_prompt
        else:
            prompt = (
                f"Based on the previous analysis: \"{analysis}\". "
                "Provide further elaboration and refined insights, focusing on potential security threats, anomalous events, "
                "and details that would help a security team understand the situation better."
            )
            if user_query:
                prompt += f" Remember to focus on: {user_query}"

        try:
            # Guard against a None reply so the error-append below cannot fail.
            analysis = call_gemini(video_url, prompt) or ""
            any_success = True
        except Exception as e:
            analysis += f"\n[Error during iteration {i+1}: {e}]"
            break

    markdown_report = f"## Video Analysis Report\n\n**Summary:**\n\n{analysis}\n"

    key_frames_gallery = get_key_frames(video_url, analysis, user_query) if any_success else []
    if not key_frames_gallery:
        markdown_report += "\n*No key frames were extracted.*\n"
    else:
        markdown_report += "\n**Key Frames Extracted:**\n"
        markdown_report += "".join(
            f"- **Frame {idx}:** {caption}\n"
            for idx, (_, caption) in enumerate(key_frames_gallery, start=1)
        )

    return markdown_report, key_frames_gallery

def gradio_interface(video_url: str, user_query: str) -> tuple[str, list]:
    """
    Gradio callback: normalise the inputs, then delegate to ``analyze_video``.

    Args:
        video_url: URL typed by the user; surrounding whitespace is ignored.
        user_query: Optional analysis focus; ``None`` or whitespace-only means
            "no query".

    Returns:
        ``(markdown_report, gallery)``. For a missing or blank URL, returns a
        prompt message and an empty gallery without calling the model.
    """
    # Inputs may be None or carry stray whitespace (e.g. from copy-paste);
    # a whitespace-only URL must not reach the downloader/model.
    video_url = (video_url or "").strip()
    user_query = (user_query or "").strip()
    if not video_url:
        return "Please provide a valid video URL.", []
    return analyze_video(video_url, user_query)

def _build_interface() -> gr.Interface:
    """Assemble the Gradio UI: URL and query text inputs, a Markdown report and a key-frame gallery."""
    # Components are created in the same order as before: inputs first, then outputs.
    url_box = gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube link or direct video file URL)")
    query_box = gr.Textbox(
        label="Analysis Query (optional): guide the focus of the analysis",
        placeholder="e.g., focus on unusual movements near the entrance",
    )
    report_view = gr.Markdown(label="Security & Surveillance Analysis Report")
    frames_view = gr.Gallery(label="Extracted Key Frames", columns=2)

    return gr.Interface(
        fn=gradio_interface,
        inputs=[url_box, query_box],
        outputs=[report_view, frames_view],
        title="AI Video Analysis and Summariser Agent",
        description=(
            "This agentic video analysis tool uses Google's Gemini 2.0 Flash model via AI Studio "
            "to iteratively analyze a video for security and surveillance insights. Provide a video URL and, optionally, "
            "a query to guide the analysis. The tool returns a detailed Markdown report along with a gallery of key frame images."
        ),
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()