|
import os |
|
import shutil |
|
from tqdm import tqdm |
|
import cv2 |
|
import gradio as gr |
|
import pandas as pd |
|
import torch |
|
from PIL import Image |
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection |
|
import math |
|
import zipfile |
|
from utils import plot_predictions, mp4_to_png, vid_stitcher |
|
|
|
def owl_batch_video(
    input_vids: list[str],
    target_prompt: list[str],
    species_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    batch_size: int = 8,
    save_dir: str = "temp/",
    progress=gr.Progress()
):
    """Run OWLv2 detection on several videos and bundle all outputs into a zip.

    Each video is handed to ``owl_video_detection``. A per-video summary is written to
    ``<save_dir>/detection_results.csv``, and the whole of ``save_dir`` is then zipped.

    Args:
        input_vids: Paths of the videos to process.
        target_prompt: Query strings that count as a positive detection.
        species_prompt: Comma-separated query string sent to the detector.
        threshold: Minimum detection score to keep a box.
        fps_processed: Frame stride forwarded to ``owl_video_detection``.
        scaling_factor: Frame resize factor forwarded to ``owl_video_detection``.
        batch_size: Frames per model batch.
        save_dir: Directory that receives all per-video outputs.
        progress: Gradio progress tracker used to report per-video progress.

    Returns:
        The path of the zip archive ("results.zip", relative to the working directory).
    """
    # owl_video_detection also creates this directory, but with an empty input list it
    # would never run, and the summary CSV below would have nowhere to go.
    os.makedirs(save_dir, exist_ok=True)

    rows = []
    for vid in progress.tqdm(input_vids, desc="Processing videos"):
        detected = owl_video_detection(vid,
                                       target_prompt,
                                       species_prompt,
                                       threshold,
                                       fps_processed=fps_processed,
                                       scaling_factor=scaling_factor,
                                       batch_size=batch_size,
                                       save_dir=save_dir)
        # Stored as the strings "True"/"False" to keep the existing CSV format.
        rows.append({"video path": vid, "detection?": "True" if detected else "False"})

    # Build the frame once rather than concatenating per row (quadratic, and pandas
    # warns about concatenating onto an empty DataFrame).
    df = pd.DataFrame(rows, columns=["video path", "detection?"])
    df.to_csv(os.path.join(save_dir, "detection_results.csv"))

    zip_file = "results.zip"
    zip_directory(save_dir, zip_file)

    return zip_file
|
|
|
|
|
|
|
def zip_directory(folder_path, output_zip_path):
    """Write every file under ``folder_path`` into a deflate-compressed zip archive.

    Archive member names are relative to ``folder_path``; empty directories are not
    stored. If ``output_zip_path`` itself lies inside ``folder_path``, it is skipped so
    the archive never tries to include a partial copy of itself.

    Args:
        folder_path: Directory whose contents are archived (recursively).
        output_zip_path: Path of the zip file to create (overwritten if it exists).
    """
    output_abs = os.path.abspath(output_zip_path)
    with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(folder_path):
            for name in files:
                file_path = os.path.join(root, name)
                # The archive is created before the walk, so os.walk can see it.
                if os.path.abspath(file_path) == output_abs:
                    continue
                arcname = os.path.relpath(file_path, start=folder_path)
                zipf.write(file_path, arcname)
|
|
|
|
|
def preprocess_text(text_prompt: str, num_prompts: int = 1):
    """Split a comma-separated prompt string into one query list per image.

    Each comma-separated entry is stripped of surrounding whitespace. The returned
    inner lists are independent copies, so mutating one does not affect the others.

    Example:
        preprocess_text("a, b, c", 2) -> [["a", "b", "c"], ["a", "b", "c"]]

    Args:
        text_prompt: Comma-separated queries, e.g. "a, b, c".
        num_prompts: How many copies of the query list to return (one per image).

    Returns:
        A list of ``num_prompts`` lists of query strings.
    """
    queries = [part.strip() for part in text_prompt.split(",")]
    # Copy per image; ``[queries] * n`` would alias the same list n times.
    return [list(queries) for _ in range(num_prompts)]
|
|
|
def owl_batch_prediction(
    images: list,
    text_queries: list[list[str]],
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """Run OWLv2 on a batch of images and post-process its detections.

    Args:
        images: PIL images; each must expose ``.size`` as ``(width, height)``.
        text_queries: One list of query strings per image (see ``preprocess_text``).
        threshold: Minimum score for a detection to be kept.
        processor: An OWLv2 processor (e.g. ``Owlv2Processor``).
        model: An OWLv2 detection model already placed on ``device``.
        device: Torch device for the model inputs and target sizes.

    Returns:
        Whatever ``processor.post_process_object_detection`` returns. The caller reads
        "boxes", "scores" and "labels" from each per-image entry.
    """
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # OWLv2 pads every image to a square (bottom/right) before inference, so predicted
    # boxes are normalized to that padded square. Scaling by (max_side, max_side) maps
    # them back onto the original pixel grid. If the installed transformers version
    # already scales by the max side internally, a square target gives the same result.
    sides = [max(img.size) for img in images]
    target_sizes = torch.tensor([[side, side] for side in sides], dtype=torch.float32, device=device)

    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)

    return results
|
|
|
|
|
def count_pos(phrases: list[str], text_targets: list[str]) -> int:
    """
    Return how many entries of ``phrases`` exactly equal one of ``text_targets``.

    Repeated phrases are each counted. Matching is exact, case-sensitive equality.

    Args:
        phrases: Strings to check (e.g. predicted labels for one frame).
        text_targets: Strings that count as a hit.

    Returns:
        The number of hits; 0 if either list is empty.
    """
    if not len(phrases) or not len(text_targets):
        return 0

    wanted = frozenset(text_targets)
    hits = 0
    for candidate in phrases:
        if candidate in wanted:
            hits += 1
    return hits
|
|
|
|
|
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16-ensemble"

# (processor, model) pairs keyed by device. They are filled on first use, so importing
# this module neither downloads weights nor requires CUDA.
_OWLV2_CACHE = {}


def _default_owlv2(device: str):
    """Return a cached (processor, model) pair for the default OWLv2 checkpoint on ``device``."""
    if device not in _OWLV2_CACHE:
        processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)
        model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT).to(device)
        _OWLV2_CACHE[device] = (processor, model)
    return _OWLV2_CACHE[device]


def _frame_sort_key(name: str):
    """Natural-sort key so that 'frame_2.png' sorts before 'frame_10.png'."""
    import re

    # re.split with a capturing group alternates text (even index) and digits (odd index).
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", name))]


def _select_frames(frames_dir: str, stride: int) -> list[str]:
    """Return every ``stride``-th frame path in ``frames_dir``, in natural filename order."""
    names = sorted(os.listdir(frames_dir), key=_frame_sort_key)
    return [os.path.join(frames_dir, name) for name in names[::stride]]


def _load_rgb(path: str):
    """Decode an image as RGB and release its file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def owl_video_detection(
    vid_path: str,
    text_target: list[str],
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    processor=None,
    model=None,
    device: str = 'cuda',
    batch_size: int = 8,
    save_dir: str = "temp/",
):
    """
    Run OWLv2 on a video's frames and report whether ``text_target`` was detected.

    The video is split into frames with ``mp4_to_png``. Every ``fps_processed``-th frame
    is batched through the model. Frames that contain a target label are overwritten
    with an annotated version. Processing stops at the first batch whose positive frames
    reach the detection rule described below. On detection, the frames are stitched into
    ``<save_dir>/positives/<name>_<threshold>.mp4``. Per-frame results are written to a
    CSV in ``positives/`` or ``negatives/``, and the temporary frame directory is always
    removed.

    Args:
        vid_path: Path of the video to analyse.
        text_target: Labels that count as a positive detection.
        text_prompt: Comma-separated queries sent to the model.
        threshold: Minimum detection score.
        fps_processed: Frame stride (every n-th extracted frame is processed); must be >= 1.
        scaling_factor: Resize factor forwarded to ``mp4_to_png``.
        processor: OWLv2 processor; if None, the default checkpoint is loaded and cached.
        model: OWLv2 model on ``device``; if None, the default checkpoint is loaded and cached.
        device: Torch device used for inference.
        batch_size: Frames per model batch; must be >= 1.
        save_dir: Output directory.

    Returns:
        True if a target was detected, False otherwise.

    Raises:
        ValueError: if ``fps_processed`` or ``batch_size`` is less than 1.
    """
    if fps_processed < 1:
        raise ValueError(f"fps_processed must be >= 1, got {fps_processed}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if processor is None or model is None:
        default_processor, default_model = _default_owlv2(device)
        if processor is None:
            processor = default_processor
        if model is None:
            model = default_model

    positives_dir = os.path.join(save_dir, "positives")
    negatives_dir = os.path.join(save_dir, "negatives")
    os.makedirs(positives_dir, exist_ok=True)
    os.makedirs(negatives_dir, exist_ok=True)

    filename = os.path.splitext(os.path.basename(vid_path))[0]
    frames_dir = os.path.join(save_dir, f"{filename}_frames")
    os.makedirs(frames_dir, exist_ok=True)

    # A full batch needs more than ceil(2/3 * batch_size) positive frames (the original
    # rule). The requirement is capped at the batch's actual length, so that batch sizes
    # 1-2 and a short final batch can still trigger a detection.
    full_batch_required = math.ceil(2 / 3 * batch_size) + 1

    rows = []
    detected = False
    try:
        fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
        frame_paths = _select_frames(frames_dir, fps_processed)

        for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[start:start + batch_size]
            images = [_load_rgb(path) for path in batch_paths]

            text_queries = preprocess_text(text_prompt, len(batch_paths))
            results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
            query_names = text_queries[0]

            batch_pos = 0
            for frame_path, result in zip(batch_paths, results):
                boxes = result['boxes'].cpu().numpy()
                scores = result['scores'].cpu().numpy()
                # Label ids index into this batch's query list.
                labels = [query_names[k] for k in result['labels'].tolist()]
                count = count_pos(labels, text_target)
                rows.append({"frame": frame_path, "boxes": boxes, "scores": scores,
                             "labels": labels, "count": count})

                if count > 0:
                    annotated_frame = plot_predictions(frame_path, labels, scores, boxes)
                    cv2.imwrite(frame_path, annotated_frame)
                    batch_pos += 1

            if batch_pos >= min(full_batch_required, len(batch_paths)):
                detected = True
                vid_stitcher(frames_dir, os.path.join(positives_dir, f"{filename}_{threshold}.mp4"), fps)
                break
    finally:
        # Remove the extracted frames even if extraction or inference fails.
        shutil.rmtree(frames_dir, ignore_errors=True)

    df = pd.DataFrame(rows, columns=["frame", "boxes", "scores", "labels", "count"])
    out_dir = positives_dir if detected else negatives_dir
    df.to_csv(os.path.join(out_dir, f"{filename}_{threshold}.csv"), index=False)

    return detected
|
|
|
|