codelion commited on
Commit
cba459f
·
verified ·
1 Parent(s): 03c6357

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import gradio as gr
4
  import cv2
5
  from google import genai
6
- from google.genai.types import Part
7
  from tenacity import retry, stop_after_attempt, wait_random_exponential
8
 
9
  # Retrieve API key from environment variables.
@@ -18,11 +18,11 @@ client = genai.Client(api_key=GOOGLE_API_KEY)
18
  MODEL_NAME = "gemini-2.0-flash-001"
19
 
20
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
21
- def call_gemini(video_file: str, prompt: str) -> str:
22
  """
23
  Call the Gemini model with the provided video file and prompt.
24
- The video file is read as bytes and passed with MIME type "video/mp4",
25
- and the prompt is wrapped as text.
26
  """
27
  with open(video_file, "rb") as f:
28
  file_bytes = f.read()
@@ -32,6 +32,7 @@ def call_gemini(video_file: str, prompt: str) -> str:
32
  Part(file_data=file_bytes, mime_type="video/mp4"),
33
  Part(text=prompt)
34
  ],
 
35
  )
36
  return response.text
37
 
@@ -42,9 +43,9 @@ def hhmmss_to_seconds(time_str: str) -> float:
42
  parts = time_str.strip().split(":")
43
  parts = [float(p) for p in parts]
44
  if len(parts) == 3:
45
- return parts[0]*3600 + parts[1]*60 + parts[2]
46
  elif len(parts) == 2:
47
- return parts[0]*60 + parts[1]
48
  else:
49
  return parts[0]
50
 
@@ -55,18 +56,35 @@ def get_key_frames(video_file: str, analysis: str, user_query: str) -> list:
55
 
56
  Returns a list of tuples: (image_array, caption)
57
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  prompt = (
59
- "Based on the following video analysis, identify key frames that best illustrate "
60
- "the important events or anomalies. Return a JSON array where each element is an object "
61
- "with two keys: 'timestamp' (in HH:MM:SS format) and 'description' (a brief explanation of why "
62
- "this frame is important)."
63
  )
64
  prompt += f" Video Analysis: {analysis}"
65
  if user_query:
66
  prompt += f" Additional focus: {user_query}"
67
 
68
  try:
69
- key_frames_response = call_gemini(video_file, prompt)
70
  key_frames = json.loads(key_frames_response)
71
  if not isinstance(key_frames, list):
72
  key_frames = []
@@ -87,6 +105,7 @@ def get_key_frames(video_file: str, analysis: str, user_query: str) -> list:
87
  seconds = hhmmss_to_seconds(ts)
88
  except Exception:
89
  continue
 
90
  cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
91
  ret, frame = cap.read()
92
  if ret:
 
3
  import gradio as gr
4
  import cv2
5
  from google import genai
6
+ from google.genai.types import Part, GenerateContentConfig
7
  from tenacity import retry, stop_after_attempt, wait_random_exponential
8
 
9
  # Retrieve API key from environment variables.
 
18
  MODEL_NAME = "gemini-2.0-flash-001"
19
 
20
  @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
21
+ def call_gemini(video_file: str, prompt: str, config: GenerateContentConfig = None) -> str:
22
  """
23
  Call the Gemini model with the provided video file and prompt.
24
+ The video file is read as bytes and passed with MIME type "video/mp4".
25
+ Optionally accepts a config (e.g. response_schema) for structured output.
26
  """
27
  with open(video_file, "rb") as f:
28
  file_bytes = f.read()
 
32
  Part(file_data=file_bytes, mime_type="video/mp4"),
33
  Part(text=prompt)
34
  ],
35
+ config=config
36
  )
37
  return response.text
38
 
 
43
  parts = time_str.strip().split(":")
44
  parts = [float(p) for p in parts]
45
  if len(parts) == 3:
46
+ return parts[0] * 3600 + parts[1] * 60 + parts[2]
47
  elif len(parts) == 2:
48
+ return parts[0] * 60 + parts[1]
49
  else:
50
  return parts[0]
51
 
 
56
 
57
  Returns a list of tuples: (image_array, caption)
58
  """
59
+ # Define a response schema for key frames.
60
+ response_schema = {
61
+ "type": "ARRAY",
62
+ "items": {
63
+ "type": "OBJECT",
64
+ "properties": {
65
+ "timestamp": {"type": "string"},
66
+ "description": {"type": "string"}
67
+ },
68
+ "required": ["timestamp", "description"]
69
+ }
70
+ }
71
+ config = GenerateContentConfig(
72
+ temperature=0.0,
73
+ max_output_tokens=1024,
74
+ response_mime_type="application/json",
75
+ response_schema=response_schema
76
+ )
77
  prompt = (
78
+ "From the following video analysis, list key frames with their timestamps (in HH:MM:SS format) "
79
+ "and a brief description of the important event at that timestamp. "
80
+ "Return the result as a JSON array of objects with keys 'timestamp' and 'description'."
 
81
  )
82
  prompt += f" Video Analysis: {analysis}"
83
  if user_query:
84
  prompt += f" Additional focus: {user_query}"
85
 
86
  try:
87
+ key_frames_response = call_gemini(video_file, prompt, config=config)
88
  key_frames = json.loads(key_frames_response)
89
  if not isinstance(key_frames, list):
90
  key_frames = []
 
105
  seconds = hhmmss_to_seconds(ts)
106
  except Exception:
107
  continue
108
+ # Set video position (in milliseconds)
109
  cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
110
  ret, frame = cap.read()
111
  if ret: