codelion committed
Commit 7c2c622 · verified · 1 Parent(s): 0f96bc2

Update app.py

Files changed (1): app.py (+101 -41)
app.py CHANGED
@@ -1,6 +1,8 @@
 import os
+import json
 import gradio as gr
-import matplotlib.pyplot as plt
+import cv2
+import matplotlib.pyplot as plt  # imported for compatibility if needed later
 from collections import Counter
 from google import genai
 from google.genai import types
@@ -16,13 +18,13 @@ if not GOOGLE_API_KEY:
 client = genai.Client(api_key=GOOGLE_API_KEY)

 # Use the Gemini 2.0 Flash model.
-MODEL_NAME = "gemini-2.0-flash-001"
+MODEL_NAME = "gemini-2.0-flash"

 @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
 def call_gemini(video_url: str, prompt: str) -> str:
     """
     Call the Gemini model with the provided video URL and prompt.
-    The video URL is passed as a URI part with MIME type "video/webm".
+    The video is passed as a URI part with MIME type "video/webm".
     """
     response = client.models.generate_content(
         model=MODEL_NAME,
@@ -33,48 +35,100 @@ def call_gemini(video_url: str, prompt: str) -> str:
     )
     return response.text

-def generate_chart(analysis_text: str) -> plt.Figure:
+def hhmmss_to_seconds(time_str: str) -> float:
     """
-    Create a simple bar chart based on the frequency of selected keywords in the analysis.
+    Convert a HH:MM:SS formatted string into seconds.
     """
-    # Define keywords of interest
-    keywords = ["suspicious", "anomaly", "incident", "alert", "object", "movement"]
-    # Lowercase the analysis text and split into words
-    words = analysis_text.lower().split()
-    # Count occurrences for each keyword
-    counter = Counter({kw: words.count(kw) for kw in keywords})
+    parts = time_str.strip().split(":")
+    parts = [float(p) for p in parts]
+    if len(parts) == 3:
+        return parts[0]*3600 + parts[1]*60 + parts[2]
+    elif len(parts) == 2:
+        return parts[0]*60 + parts[1]
+    else:
+        return parts[0]
+
+def get_key_frames(video_url: str, analysis: str, user_query: str) -> list:
+    """
+    Prompt Gemini to return key frame timestamps (in HH:MM:SS) with descriptions,
+    then extract those frames from the video using OpenCV.
+
+    Returns a list of tuples: (image_array, caption)
+    """
+    prompt = (
+        "Based on the following video analysis, identify key frames that best illustrate "
+        "the important events or anomalies. Return a JSON array where each element is an object "
+        "with two keys: 'timestamp' (in HH:MM:SS format) and 'description' (a brief explanation of why "
+        "this frame is important)."
+    )
+    prompt += f" Video Analysis: {analysis}"
+    if user_query:
+        prompt += f" Additional focus: {user_query}"
+
+    try:
+        key_frames_response = call_gemini(video_url, prompt)
+        # Attempt to parse the output as JSON.
+        key_frames = json.loads(key_frames_response)
+        if not isinstance(key_frames, list):
+            key_frames = []
+    except Exception as e:
+        key_frames = []

-    # Create a bar chart using matplotlib
-    fig, ax = plt.subplots(figsize=(6, 4))
-    ax.bar(counter.keys(), counter.values(), color="skyblue")
-    ax.set_title("Keyword Frequency in Analysis")
-    ax.set_ylabel("Count")
-    ax.set_xlabel("Keyword")
-    plt.tight_layout()
-    return fig
+    extracted_frames = []
+    cap = cv2.VideoCapture(video_url)
+    if not cap.isOpened():
+        print("Error: Could not open video.")
+        return extracted_frames

-def analyze_video(video_url: str, user_query: str) -> (str, plt.Figure):
+    for frame_obj in key_frames:
+        ts = frame_obj.get("timestamp")
+        description = frame_obj.get("description", "")
+        try:
+            seconds = hhmmss_to_seconds(ts)
+        except Exception:
+            continue
+        # Set video position (in milliseconds)
+        cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
+        ret, frame = cap.read()
+        if ret:
+            # Convert BGR to RGB
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            caption = f"{ts}: {description}"
+            extracted_frames.append((frame_rgb, caption))
+    cap.release()
+    return extracted_frames
+
+def analyze_video(video_url: str, user_query: str) -> (str, list):
     """
-    Perform iterative (agentic) video analysis.
-    The analysis is refined over several iterations, incorporating the user query if provided.
-    Returns a Markdown report and a matplotlib chart.
+    Perform iterative, agentic video analysis.
+    First, refine the video analysis over several iterations.
+    Then, prompt the model to identify key frames.
+
+    Returns:
+      - A Markdown report as a string.
+      - A gallery list of key frames (each as a tuple of (image, caption)).
     """
     analysis = ""
     num_iterations = 3

     for i in range(num_iterations):
-        base_prompt = "You are a video analysis agent focusing on security and surveillance. Provide a detailed summary of the video, highlighting key events, suspicious activities, or anomalies."
+        base_prompt = (
+            "You are a video analysis agent focusing on security and surveillance. "
+            "Provide a detailed summary of the video, highlighting key events, suspicious activities, or anomalies."
+        )
         if user_query:
             base_prompt += f" Also, focus on the following query: {user_query}"

         if i == 0:
             prompt = base_prompt
         else:
-            prompt = (f"Based on the previous analysis: \"{analysis}\". "
-                      "Provide further elaboration and refined insights, focusing on potential security threats, anomalous events, "
-                      "and details that would help a security team understand the situation better. ")
+            prompt = (
+                f"Based on the previous analysis: \"{analysis}\". "
+                "Provide further elaboration and refined insights, focusing on potential security threats, anomalous events, "
+                "and details that would help a security team understand the situation better."
+            )
             if user_query:
-                prompt += f"Remember to focus on: {user_query}"
+                prompt += f" Remember to focus on: {user_query}"

         try:
             analysis = call_gemini(video_url, prompt)
@@ -82,39 +136,45 @@ def analyze_video(video_url: str, user_query: str) -> (str, plt.Figure):
             analysis += f"\n[Error during iteration {i+1}: {e}]"
             break

-    # Create a Markdown report (adding headings and bullet points if desired)
+    # Create a Markdown report
     markdown_report = f"## Video Analysis Report\n\n**Summary:**\n\n{analysis}\n"
-
-    # Generate a chart visualization based on the analysis text.
-    chart_fig = generate_chart(analysis)
-    return markdown_report, chart_fig

-def gradio_interface(video_url: str, user_query: str) -> (str, any):
+    # Get key frames based on the analysis and optional query.
+    key_frames_gallery = get_key_frames(video_url, analysis, user_query)
+    if not key_frames_gallery:
+        markdown_report += "\n*No key frames were extracted.*\n"
+    else:
+        markdown_report += "\n**Key Frames Extracted:**\n"
+        for idx, (img, caption) in enumerate(key_frames_gallery, start=1):
+            markdown_report += f"- **Frame {idx}:** {caption}\n"
+
+    return markdown_report, key_frames_gallery
+
+def gradio_interface(video_url: str, user_query: str) -> (str, list):
     """
-    Gradio interface function that takes a video URL and an optional query,
-    then returns a Markdown report and a visualization chart.
+    Gradio interface function that accepts a video URL and an optional query,
+    then returns a Markdown report and a gallery of key frame images with captions.
     """
     if not video_url:
-        return "Please provide a valid video URL.", None
+        return "Please provide a valid video URL.", []
     return analyze_video(video_url, user_query)

 # Define the Gradio interface with two inputs and two outputs.
 iface = gr.Interface(
     fn=gradio_interface,
     inputs=[
-        gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube link)"),
+        gr.Textbox(label="Video URL (publicly accessible, e.g., YouTube direct link or video file URL)"),
         gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis", placeholder="e.g., focus on unusual movements near the entrance")
     ],
     outputs=[
         gr.Markdown(label="Security & Surveillance Analysis Report"),
-        gr.Plot(label="Visualization: Keyword Frequency")
+        gr.Gallery(label="Extracted Key Frames").style(grid=[2], height="auto")
     ],
     title="AI Video Analysis and Summariser Agent",
     description=(
         "This agentic video analysis tool uses Google's Gemini 2.0 Flash model via AI Studio "
         "to iteratively analyze a video for security and surveillance insights. Provide a video URL and, optionally, "
-        "a query to guide the analysis. The tool returns a detailed Markdown report along with a bar chart visualization "
-        "of keyword frequency."
+        "a query to guide the analysis. The tool returns a detailed Markdown report along with a gallery of key frame images."
     )
 )
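
A note on the new key-frame path: get_key_frames hands Gemini's reply straight to json.loads, which raises whenever the model wraps the array in a Markdown code fence or adds prose around it, so the except branch silently discards all key frames. A minimal hardening sketch, not part of this commit (the _extract_json_array helper name is illustrative):

    import json
    import re

    def _extract_json_array(raw: str) -> list:
        """Best-effort parse of a JSON array from model output (hypothetical helper)."""
        text = raw.strip()
        # Drop a surrounding Markdown code fence such as ```json ... ``` if present.
        if text.startswith("```"):
            text = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", text)
        # Fall back to the first [...] span in the reply.
        match = re.search(r"\[.*\]", text, re.DOTALL)
        if match:
            text = match.group(0)
        try:
            data = json.loads(text)
            return data if isinstance(data, list) else []
        except json.JSONDecodeError:
            return []

Inside get_key_frames, key_frames = _extract_json_array(key_frames_response) could then replace the bare json.loads call.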
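The new output component, gr.Gallery(...).style(grid=[2], height="auto"), uses the Gradio 3.x styling API; .style() was removed in Gradio 4. If the Space runs a newer Gradio (an assumption about the environment, not something this commit states), the equivalent would be constructor arguments:

    # Gradio 4.x equivalent of .style(grid=[2], height="auto") on the Gallery output.
    gr.Gallery(label="Extracted Key Frames", columns=2, height="auto")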
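Also worth noting: cv2.VideoCapture(video_url) typically opens only local files or direct media URLs, so a YouTube watch-page link that Gemini itself can ingest will fail the cap.isOpened() check and no frames will be extracted. A quick way to check whether a given URL is usable for frame extraction (the sample URL is a placeholder):

    import cv2

    # Placeholder URL; it must point at an actual media file (e.g., .mp4/.webm), not a YouTube page.
    cap = cv2.VideoCapture("https://example.com/sample.webm")
    print("OpenCV can read this URL:", cap.isOpened())
    cap.release()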