annading committed
Commit ca1863b · 1 Parent(s): e365924

first commit

Files changed (6)
  1. .gitignore +6 -0
  2. README.md +1 -1
  3. app_batch.py +134 -0
  4. owl_batch.py +210 -0
  5. requirements.txt +9 -0
  6. utils.py +103 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pyc
+ */__pycache__/**
+ *.mp4
+ *.png
+ *.csv
+ /.gradio/
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
  colorTo: gray
  sdk: gradio
  sdk_version: 5.24.0
- app_file: app.py
+ app_file: app_batch.py
  pinned: false
  license: apache-2.0
  ---
app_batch.py ADDED
@@ -0,0 +1,134 @@
+ BATCH_SIZE = 8  # change this to your desired batch size
+ CUDA_PATH = "/usr/local/cuda-12.3/"  # change this to your CUDA path
+
+ from datetime import datetime
+ import os
+ import time
+
+ # set CUDA_HOME before any CUDA-dependent imports
+ os.environ["CUDA_HOME"] = CUDA_PATH
+
+ import gradio as gr
+
+ from owl_batch import owl_batch_video
+
+
+ def run_owl_batch(
+     input_vids: list[str] | str,
+     target_prompt: str,
+     species_prompt: str,
+     conf_threshold: float,
+     fps_processed: int,
+     scaling_factor: float
+ ) -> str:
+     """
+     args:
+         input_vids: list of video paths (or a single path)
+         target_prompt: comma-separated phrases to search for
+         species_prompt: comma-separated phrases describing all species in the dataset
+         conf_threshold: confidence threshold for detection
+         fps_processed: run detection on every Nth frame
+         scaling_factor: factor by which to downsample the frames
+     returns:
+         zip_path: path to a zip containing the results CSV and the
+             positive/negative videos with their per-video results
+     """
+     start_time = time.time()
+     if isinstance(input_vids, str):
+         input_vids = [input_vids]
+
+     # make sure there are no spaces in the file names
+     renamed_vids = []
+     for vid in input_vids:
+         new_input_vid = vid.replace(" ", "_")
+         if new_input_vid != vid:
+             os.rename(vid, new_input_vid)
+         renamed_vids.append(new_input_vid)
+     input_vids = renamed_vids
+
+     # species prompt has to contain the target prompt, otherwise add it
+     if target_prompt not in species_prompt:
+         species_prompt = f"{species_prompt}, {target_prompt}"
+
+     # turn the target prompt into a list
+     target_prompt = target_prompt.split(", ")
+
+     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
+
+     zip_path = owl_batch_video(
+         input_vids,
+         target_prompt,
+         species_prompt,
+         conf_threshold,
+         fps_processed=fps_processed,
+         scaling_factor=1 / scaling_factor,
+         batch_size=BATCH_SIZE,
+         save_dir=f"temp_{timestamp}")
+
+     end_time = time.time()
+     print(f"Processing time: {end_time - start_time:.1f} seconds")
+     return zip_path
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             input = gr.File(label="Upload Videos", file_types=['.mp4', '.mov'], file_count="multiple")
+             target_prompt = gr.Textbox(label="What do you want to detect? (Multiple species should be separated by commas)")
+             species_prompt = gr.Textbox(label="Which species are in your dataset? (Multiple species should be separated by commas)")
+             with gr.Accordion("Advanced Options", open=False):
+                 conf_threshold = gr.Slider(
+                     label="Confidence Threshold",
+                     info="Adjust the threshold to change the sensitivity of the model; lower thresholds are more sensitive.",
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.3,
+                     step=0.05
+                 )
+                 fps_processed = gr.Slider(
+                     label="Frame Detection Rate",
+                     info="Run detection on every Nth frame: a value of 120 runs detection every 120 frames, a value of 1 runs it on every frame. Note: the lower the number, the slower the processing.",
+                     minimum=1,
+                     maximum=120,
+                     value=10,
+                     step=1)
+                 scaling_factor = gr.Slider(
+                     label="Downsample Factor",
+                     info="Adjust the downsample factor. Note: the higher the number, the faster the processing but the lower the accuracy.",
+                     minimum=1,
+                     maximum=10,
+                     value=4,
+                     step=1
+                 )
+             with gr.Row():
+                 clear_btn = gr.ClearButton(components=[input, target_prompt, species_prompt])
+                 run_btn = gr.Button(value="Run Detection", variant='primary')
+         with gr.Column():
+             download_file = gr.Files(label="CSV, Video Output", interactive=False)
+
+     run_btn.click(fn=run_owl_batch, inputs=[input, target_prompt, species_prompt, conf_threshold, fps_processed, scaling_factor], outputs=[download_file])
+
+     gr.DuplicateButton()
+
+     gr.Markdown(
+         """
+         ## Frequently Asked Questions
+
+         ##### How can I run the interface on my own computer?
+         Click the three dots in the top-right corner of the interface to clone the repository or run it with a Docker image on your local machine. \
+         For local setup instructions, please check the README file.
+         ##### The video is very slow to process, how can I speed it up?
+         Increase the Frame Detection Rate in the Advanced Options so detection runs on fewer frames, or increase the Downsample Factor. \
+         You can also duplicate the space using the Duplicate Button and choose a faster GPU.
+         """
+     )
+
+ demo.launch(share=True)
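
For scripted use outside the Gradio UI, the same pipeline can be driven directly through `owl_batch_video`. A minimal sketch, not part of this commit, assuming a CUDA-capable machine and some hypothetical local clips under `clips/`:

```python
from glob import glob

from owl_batch import owl_batch_video

videos = sorted(glob("clips/*.mp4"))        # hypothetical input videos
zip_path = owl_batch_video(
    videos,
    ["gorilla"],                            # target_prompt: phrases that count as a detection
    "gorilla, chimpanzee, human",           # species_prompt: everything the model may see
    0.3,                                    # confidence threshold
    fps_processed=10,                       # run detection on every 10th extracted frame
    scaling_factor=1 / 4,                   # downsample frames to a quarter of their size
    batch_size=8,
    save_dir="temp_headless")
print(zip_path)                             # -> temp_headless/results.zip
```

This mirrors what `run_owl_batch` does after it normalises its inputs: `target_prompt` is already a list of phrases, and `scaling_factor` is the reciprocal of the UI's downsample factor.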
owl_batch.py ADDED
@@ -0,0 +1,210 @@
+ import os
+ import shutil
+ import math
+ import zipfile
+
+ import cv2
+ import pandas as pd
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
+
+ from utils import plot_predictions, mp4_to_png, vid_stitcher
+
+
+ def owl_batch_video(
+     input_vids: list[str],
+     target_prompt: list[str],
+     species_prompt: str,
+     threshold: float,
+     fps_processed: int = 1,
+     scaling_factor: float = 0.5,
+     batch_size: int = 8,
+     save_dir: str = "temp/"
+ ):
+     """Runs owl_video_detection on every video, saves a summary CSV, and zips the results."""
+     df = pd.DataFrame(columns=["video path", "detection?"])
+
+     for vid in input_vids:
+         detection = owl_video_detection(vid,
+                                         target_prompt,
+                                         species_prompt,
+                                         threshold,
+                                         fps_processed=fps_processed,
+                                         scaling_factor=scaling_factor,
+                                         batch_size=batch_size,
+                                         save_dir=save_dir)
+
+         row = pd.DataFrame({"video path": [vid], "detection?": [str(detection)]})
+         df = pd.concat([df, row], ignore_index=True)
+
+     # save the per-video summary
+     df.to_csv(f"{save_dir}/detection_results.csv", index=False)
+
+     # zip the save_dir
+     zip_file = f"{save_dir}/results.zip"
+     zip_directory(save_dir, zip_file)
+
+     return zip_file
+
+
+ def zip_directory(folder_path, output_zip_path):
+     with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+         for root, dirs, files in os.walk(folder_path):
+             for file in files:
+                 file_path = os.path.join(root, file)
+                 # skip the archive itself if it lives inside folder_path
+                 if os.path.abspath(file_path) == os.path.abspath(output_zip_path):
+                     continue
+                 # write the file with a relative path to preserve folder structure
+                 arcname = os.path.relpath(file_path, start=folder_path)
+                 zipf.write(file_path, arcname)
+
+
+ def preprocess_text(text_prompt: str, num_prompts: int = 1):
+     """
+     Takes a string of comma-separated text prompts and returns one list of prompts per image,
+     e.g. text_prompt = "a, b, c" with num_prompts = 2 -> [["a", "b", "c"], ["a", "b", "c"]]
+     """
+     text_prompt = [s.strip() for s in text_prompt.split(",")]
+     text_queries = [text_prompt] * num_prompts
+     return text_queries
+
+
+ def owl_batch_prediction(
+     images: list[Image.Image],
+     text_queries: list[list[str]],  # every image is queried with the same text prompt
+     threshold: float,
+     processor,
+     model,
+     device: str = 'cuda'
+ ):
+     inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # target image sizes (height, width) to rescale box predictions [batch_size, 2]
+     target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
+     # convert outputs (bounding boxes and class logits) to COCO API, rescale to the original image size and filter by threshold
+     results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)
+
+     return results
+
+
+ def count_pos(phrases: list[str], text_targets: list[str]) -> int:
+     """
+     Counts how many phrases in the list match any of the target phrases.
+
+     Args:
+         phrases: A list of strings to evaluate.
+         text_targets: A list of target strings to match against.
+
+     Returns:
+         The number of phrases that match any of the targets.
+     """
+     if len(phrases) == 0 or len(text_targets) == 0:
+         return 0
+     target_set = set(text_targets)
+     return sum(1 for phrase in phrases if phrase in target_set)
+
+
+ def owl_video_detection(
+     vid_path: str,
+     text_target: list[str],
+     text_prompt: str,
+     threshold: float,
+     fps_processed: int = 1,
+     scaling_factor: float = 0.5,
+     # note: these defaults are evaluated once, when the module is imported
+     processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble"),
+     model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to('cuda'),
+     device: str = 'cuda',
+     batch_size: int = 8,
+     save_dir: str = "temp/",
+ ):
+     """
+     Runs OWLv2 on a video and saves the per-frame results to a dataframe.
+     Returns True (and stops early) as soon as one batch has more than 2/3 positive frames,
+     False otherwise.
+     """
+     os.makedirs(save_dir, exist_ok=True)
+     os.makedirs(f"{save_dir}/positives", exist_ok=True)
+     os.makedirs(f"{save_dir}/negatives", exist_ok=True)
+
+     # set up df for results
+     df = pd.DataFrame(columns=["frame", "boxes", "scores", "labels", "count"])
+
+     # create new dirs and paths for results
+     filename = os.path.splitext(os.path.basename(vid_path))[0]
+     frames_dir = f"{save_dir}/{filename}_frames"
+     os.makedirs(frames_dir, exist_ok=True)
+
+     # process the video into a directory of frames
+     fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
+
+     # get all frame filenames, in frame order
+     frame_filenames = sorted(os.listdir(frames_dir))
+
+     # keep every fps_processed-th frame
+     frame_paths = []
+     for i, frame in enumerate(frame_filenames):
+         if i % fps_processed == 0:
+             frame_paths.append(os.path.join(frames_dir, frame))
+
+     # run owl in batches
+     for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
+         batch_paths = frame_paths[i:i+batch_size]  # paths for this batch
+         images = [Image.open(image_path) for image_path in batch_paths]
+
+         # run owl on this batch of frames
+         text_queries = preprocess_text(text_prompt, len(batch_paths))
+         results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
+
+         # get the label ids for this batch; None if a frame has no detections
+         label_ids = []
+         for entry in results:
+             if entry['labels'].numel() > 0:
+                 label_ids.append(entry['labels'].tolist())
+             else:
+                 label_ids.append(None)
+
+         text = text_queries[0]  # all images share the same query list
+         labels = []
+         # convert label ids to phrases; if there are no detections, append an empty list
+         for ids in label_ids:
+             if ids is not None:
+                 labels.append([text[label_id] for label_id in ids])
+             else:
+                 labels.append([])
+
+         batch_pos = 0
+         for j, image in enumerate(batch_paths):
+             boxes = results[j]['boxes'].cpu().numpy()
+             scores = results[j]['scores'].cpu().numpy()
+             count = count_pos(labels[j], text_target)
+             row = pd.DataFrame({"frame": [image], "boxes": [boxes], "scores": [scores], "labels": [labels[j]], "count": [count]})
+             df = pd.concat([df, row], ignore_index=True)
+
+             # if there are detections, overwrite the frame with its annotated version
+             if count > 0:
+                 annotated_frame = plot_predictions(image, labels[j], scores, boxes)
+                 cv2.imwrite(image, annotated_frame)
+                 batch_pos += 1
+
+         # if more than 2/3 of the batch frames are positive, stitch the video and return True
+         if batch_pos > math.ceil(2/3 * batch_size):
+             vid_stitcher(frames_dir, f"{save_dir}/positives/{filename}_{threshold}.mp4", fps)
+             shutil.rmtree(frames_dir)  # delete the frames to save space
+             df.to_csv(f"{save_dir}/positives/{filename}_{threshold}.csv", index=False)
+             return True
+
+     shutil.rmtree(frames_dir)  # delete the frames to save space
+     df.to_csv(f"{save_dir}/negatives/{filename}_{threshold}.csv", index=False)
+     return False
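
The prompt handling above is easier to follow with a small example. A sketch, not part of this commit, using the module's own `preprocess_text` and `count_pos` with made-up species names:

```python
from owl_batch import preprocess_text, count_pos

species_prompt = "gorilla, chimpanzee, human"
queries = preprocess_text(species_prompt, num_prompts=2)
# -> [["gorilla", "chimpanzee", "human"], ["gorilla", "chimpanzee", "human"]]

# OWLv2 label ids index into the per-image query list, so id 0 here means "gorilla"
labels_for_one_frame = [queries[0][i] for i in (0, 2)]   # ["gorilla", "human"]
print(count_pos(labels_for_one_frame, ["gorilla"]))      # 1
```

With the defaults, a batch is declared positive only when `batch_pos > math.ceil(2/3 * batch_size)`; for `batch_size=8` that means `batch_pos > 6`, i.e. at least 7 of the 8 sampled frames must contain a target phrase before the early exit triggers.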
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio==5.9.1
+ numpy
+ opencv-python
+ pandas
+ Pillow
+ supervision
+ torch
+ tqdm
+ transformers
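
Note that ffmpeg is also required but is not a pip package, so it must be available on the system PATH (utils.py shells out to it). A quick environment sanity check, not part of this commit, that the pinned dependencies can load the OWLv2 checkpoint referenced in owl_batch.py:

```python
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# confirm the GPU is visible; owl_batch.py moves the model to 'cuda' at import time
print("CUDA available:", torch.cuda.is_available())

# download/load the checkpoint used by owl_batch.py
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
print("Loaded model type:", model.config.model_type)
```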
utils.py ADDED
@@ -0,0 +1,103 @@
+ import math
+ import os
+ import subprocess
+ from concurrent.futures import ThreadPoolExecutor
+ from glob import glob
+
+ import cv2
+ import numpy as np
+ import supervision as sv
+ from tqdm import tqdm
+
+
+ def plot_predictions(
+     image: str,
+     labels: list[str],
+     scores: list[float],
+     boxes: np.ndarray,
+     opacity: float = 1.0
+ ) -> np.ndarray:
+     """Draws the boxes and '<label> <score>' captions onto the image at `image` (a file path)."""
+     image_source = cv2.imread(image)
+     image_source_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)
+
+     detections = sv.Detections(xyxy=boxes)
+
+     captions = [
+         f"{phrase} {score:.2f}"
+         for phrase, score
+         in zip(labels, scores)
+     ]
+
+     # scale line thickness and text size with the image width
+     height, width, _ = image_source_rgb.shape
+     thickness = math.ceil(width / 200)
+     text_scale = width / 1500
+     text_thickness = math.ceil(text_scale * 1.5)
+
+     bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness)
+     label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, text_thickness=text_thickness)
+
+     # create a semi-transparent overlay
+     overlay = image_source_rgb.copy()
+
+     # apply bounding box and label annotations to the overlay
+     overlay = bbox_annotator.annotate(scene=overlay, detections=detections)
+     overlay = label_annotator.annotate(scene=overlay, detections=detections, labels=captions)
+
+     # blend the overlay with the original image using the specified opacity
+     annotated_frame = cv2.addWeighted(overlay, opacity, image_source_rgb, 1 - opacity, 0)
+
+     annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
+
+     return annotated_frame_bgr
+
+
+ def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
+     """Converts an mp4 into one png per frame.
+     Args: input_path is the path to the mp4 file, save_path is the directory in which to save the frames.
+     Returns: fps, the number of frames per second of the source video.
+     """
+     # get frames per second
+     cap = cv2.VideoCapture(input_path)
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     cap.release()
+     # run ffmpeg to convert the mp4 to pngs
+     subprocess.run(
+         ["ffmpeg", "-i", input_path,
+          "-vf", f"fps={fps},scale=iw*{scale_factor}:ih*{scale_factor}",
+          f"{save_path}/frame%08d.png"],
+         check=True)
+     return fps
+
+
+ def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
+     """
+     Reads the frame pngs in frames_dir and writes them to a video file at output_path.
+     """
+     # get the sorted list of frame paths
+     frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
+
+     # prepare the VideoWriter using the first frame's size
+     frame = cv2.imread(frame_list[0])
+     height, width, _ = frame.shape
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     # use multithreading to read frames faster
+     with ThreadPoolExecutor() as executor:
+         frames = list(executor.map(cv2.imread, frame_list))
+
+     # write frames to the video
+     with tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
+         for frame in frames:
+             out.write(frame)
+             pbar.update(1)
+
+     out.release()  # finalise the file so the mp4 is playable
+     return output_path
+
+
+ def count_pos(phrases, text_target):
+     """
+     Takes a list of lists of phrases and counts how many sublists contain at least one phrase equal to the target phrase.
+     """
+     num_pos = 0
+     for sublist in phrases:
+         if sublist is None:
+             continue
+         for phrase in sublist:
+             if phrase == text_target:
+                 num_pos += 1
+                 break
+     return num_pos
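
Taken together, these helpers implement the frame round trip used by `owl_video_detection`: extract frames, annotate the positive ones in place, then stitch the directory back into a video. A sketch, not part of this commit, on a hypothetical local `clip.mp4`:

```python
import os
import cv2
import numpy as np
from utils import mp4_to_png, plot_predictions, vid_stitcher

os.makedirs("frames", exist_ok=True)
fps = mp4_to_png("clip.mp4", "frames", scale_factor=0.5)      # extract downscaled frames

frame_path = "frames/frame00000001.png"
boxes = np.array([[10.0, 10.0, 100.0, 120.0]])                # one xyxy box, made-up detection
annotated = plot_predictions(frame_path, ["gorilla"], [0.91], boxes)
cv2.imwrite(frame_path, annotated)                            # overwrite the frame in place

vid_stitcher("frames", "clip_annotated.mp4", fps)             # rebuild the (annotated) video
```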