import os
import re

import numpy as np
import spacy
import torch
import matplotlib.cm as cm
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

from config import LOGS_DIR, OUTPUT_DIR
from DepthEstimator import DepthEstimator
from SoundMapper import SoundMapper
from GenerateCaptions import generate_caption, ImageAnalyzer

class ProcessVisualizer:
    def __init__(self, image_dir=LOGS_DIR, output_dir=None):
        self.image_dir = image_dir
        self.output_dir = output_dir if output_dir else os.path.join(OUTPUT_DIR, "visualizations")
        os.makedirs(self.output_dir, exist_ok=True)
        # Initialize components (but don't load models yet)
        self.depth_estimator = DepthEstimator(image_dir=self.image_dir)
        self.sound_mapper = SoundMapper()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dino = None
        self.dino_processor = None
        self.nlp = None
        # Create subdirectories for different visualization types
        self.dirs = {
            "bbox_original": os.path.join(self.output_dir, "bbox_original"),
            "bbox_depth": os.path.join(self.output_dir, "bbox_depth"),
            "depth_maps": os.path.join(self.output_dir, "depth_maps"),
            "combined": os.path.join(self.output_dir, "combined"),
        }
        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)
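
    # Models are loaded lazily and moved back to the CPU after use, so GPU
    # memory is only held while inference is actually running (presumably to
    # suit shared/ZeroGPU hardware).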
    def _load_nlp(self):
        if self.nlp is None:
            self.nlp = spacy.load("en_core_web_sm")
        return self.nlp

    def _load_dino(self):
        if self.dino is None:
            print("Loading DINO model...")
            self.dino = AutoModelForZeroShotObjectDetection.from_pretrained(
                "IDEA-Research/grounding-dino-base"
            ).to(self.device)
            self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
        else:
            self.dino = self.dino.to(self.device)
        return self.dino, self.dino_processor

    def _unload_dino(self):
        if self.dino is not None:
            self.dino = self.dino.to("cpu")
            torch.cuda.empty_cache()
    def detect_nouns(self, caption_text):
        """Extract nouns from caption text for object detection"""
        print("Detecting nouns in caption...")
        nlp = self._load_nlp()
        all_nouns = []
        # Extract nouns from sound source descriptions
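        # Captions are expected to list sound sources as numbered, bold
        # markdown items (e.g. "1. **Traffic**: ..."); the pattern below
        # captures the bold source name from each item.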
        pattern = r'\d+\.\s+\*\*([^:]+)\*\*:'
        sources = re.findall(pattern, caption_text)
        for source in sources:
            clean_source = re.sub(r'sounds?|noises?', '', source, flags=re.IGNORECASE).strip()
            if clean_source:
                source_doc = nlp(clean_source)
                for token in source_doc:
                    if token.pos_ == "NOUN" and len(token.text) > 1:
                        all_nouns.append(token.text.lower())
        # Extract nouns from the general text
        clean_caption = re.sub(r'[*()]', '', caption_text).strip()
        clean_caption = re.sub(r'##\w+', '', clean_caption)
        clean_caption = re.sub(r'\s+', ' ', clean_caption).strip()
        doc = nlp(clean_caption)
        for token in doc:
            if token.pos_ == "NOUN" and len(token.text) > 1 and token.text[0].isalpha():
                all_nouns.append(token.text.lower())
        matches = sorted(set(all_nouns))
        print(f"Detected nouns: {matches}")
        return matches
    def detect_objects(self, image_path, caption_text):
        """Detect objects in the image based on nouns from the caption"""
        print(f"Processing image: {image_path}")
        # Extract nouns from the caption
        nouns = self.detect_nouns(caption_text)
        if not nouns:
            print("No nouns detected in caption.")
            return None, None
        # Load the image and the DINO model
        image = Image.open(image_path)
        self.dino, self.dino_processor = self._load_dino()
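        # Grounding DINO takes its text prompt as a single string of
        # phrases separated by " . ", so the filtered nouns are joined
        # into one prompt below.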
        # Filter nouns (drop tokenizer artifacts and non-alphabetic tokens)
        filtered_nouns = [
            noun for noun in nouns
            if '##' not in noun and len(noun) > 1 and noun[0].isalpha()
        ]
        text_prompt = " . ".join(filtered_nouns)
        print(f"Using text prompt for DINO: {text_prompt}")
        # Run detection
        inputs = self.dino_processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.dino(**inputs)
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.25,
            text_threshold=0.25,
            target_sizes=[image.size[::-1]]  # (height, width)
        )
        # Clean up to save memory
        self._unload_dino()
        del inputs, outputs
        torch.cuda.empty_cache()
        # Process results
        result = results[0]
        labels = result["labels"]
        scores = result["scores"]
        bboxes = result["boxes"]
        # Strip tokenizer artifacts (e.g. "##ing") from the labels
        clean_labels = [re.sub(r'##\w+', '', label) for label in labels]
        print(f"Detected {len(clean_labels)} objects: {list(zip(clean_labels, scores.tolist()))}")
        return clean_labels, bboxes
    def estimate_depth(self):
        """Generate depth maps for all images in the directory"""
        print("Estimating depth for all images...")
        depth_maps = self.depth_estimator.estimate_depth(self.image_dir)
        # Convert depth maps to normalized [0, 1] arrays for visualization
        normalized_maps = []
        img_paths = [os.path.join(self.image_dir, f) for f in os.listdir(self.image_dir)
                     if f.endswith(('.jpg', '.jpeg', '.png'))]
        for i, item in enumerate(depth_maps):
            depth_map = item["depth"]
            depth_array = np.array(depth_map)
            normalization = depth_array / 255.0
            # Associate a source path with the depth map (assumes the depth
            # estimator returns maps in directory-listing order)
            source_path = img_paths[i] if i < len(img_paths) else f"depth_{i}.jpg"
            filename = os.path.basename(source_path)
            # Save the grayscale depth map
            depth_path = os.path.join(self.dirs["depth_maps"], f"depth_{filename}")
            depth_map.save(depth_path)
            normalized_maps.append({
                "original": depth_map,
                "normalization": normalization,
                "path": depth_path,
                "source_path": source_path,
            })
        return normalized_maps
    def create_histogram_depth_zones(self, depth_map, num_zones=3):
        """Split [0, 1] depth values into `num_zones` zones of roughly equal
        pixel mass, using the cumulative histogram of the depth map."""
        hist, bin_edges = np.histogram(depth_map.flatten(), bins=50, range=(0, 1))
        cumulative = np.cumsum(hist) / np.sum(hist)
        thresholds = [0.0]
        for i in range(1, num_zones):
            target = i / num_zones
            idx = np.argmin(np.abs(cumulative - target))
            thresholds.append(bin_edges[idx + 1])
        thresholds.append(1.0)
        return thresholds
    def get_depth_zone(self, bbox, depth_map, num_zones=3):
        """Determine the depth zone and mean depth for a given bounding box"""
        x1, y1, x2, y2 = [int(coord) for coord in bbox]
        # Clamp the box to the image dimensions
        height, width = depth_map.shape
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        # Extract the depth region of interest
        depth_roi = depth_map[y1:y2, x1:x2]
        if depth_roi.size == 0:
            return num_zones - 1, 1.0  # Default to the farthest zone
        # The mean depth of the box decides its zone
        mean_depth = np.mean(depth_roi)
        thresholds = self.create_histogram_depth_zones(depth_map, num_zones)
        zone = 0
        for i in range(num_zones):
            if thresholds[i] <= mean_depth < thresholds[i + 1]:
                zone = i
                break
        return zone, mean_depth
    def draw_bounding_boxes(self, image, labels, bboxes, scores=None, depth_zones=None):
        """Draw bounding boxes on an image with depth zone information"""
        draw = ImageDraw.Draw(image)
        # Try to load a TrueType font, falling back to the default bitmap font
        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except IOError:
            try:
                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
            except IOError:
                font = ImageFont.load_default()
        # Store colors as a class attribute for access in modified versions
        self.zone_colors = {
            0: (255, 50, 50),   # Bright red for near
            1: (255, 180, 0),   # Orange for medium
            2: (50, 255, 50),   # Bright green for far
        }
        for i, (label, bbox) in enumerate(zip(labels, bboxes)):
            x1, y1, x2, y2 = [int(coord) for coord in bbox]
            # Pick a color based on the depth zone, if available
            if depth_zones is not None and i < len(depth_zones):
                zone, depth = depth_zones[i]
                color = self.zone_colors.get(zone, (0, 0, 255))
                label_text = f"{depth:.2f}"
            else:
                color = (255, 50, 50)  # Default bright red
                label_text = label
            # Append the detection score if available
            if scores is not None and i < len(scores):
                label_text += f" {scores[i]:.2f}"
            # Draw the bounding box with a thick border for visibility
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            # Measure the label text (textbbox on newer Pillow, textsize on older)
            if hasattr(draw, 'textbbox'):
                left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
                text_size = (right - left, bottom - top)
            elif hasattr(draw, 'textsize'):
                text_size = draw.textsize(label_text, font=font)
            else:
                # Rough fallback: ~8 px per character, 20 px tall
                text_size = (len(label_text) * 8, 20)
            # Draw the label background with a small margin
            margin = 2
            text_box = [
                x1 - margin,
                y1 - text_size[1] - margin,
                x1 + text_size[0] + margin,
                y1 + margin,
            ]
            draw.rectangle(text_box, fill=color)
            # Draw the label text
            draw.text((x1, y1 - text_size[1]), label_text, fill=(255, 255, 255), font=font)
        return image
    def create_depth_map_visualization(self, depth_map, use_grayscale=True):
        """Create a visualization of the depth map.

        Args:
            depth_map: Normalized depth map array in [0, 1]
            use_grayscale: If True, creates a grayscale image; otherwise uses a colored heatmap

        Returns:
            PIL Image with the depth visualization
        """
        normalized_depth = depth_map.copy()
        if use_grayscale:
            # Scale to [0, 255] for an 8-bit grayscale image
            grayscale = (normalized_depth * 255).astype(np.uint8)
            # Convert to RGB so bounding boxes can be drawn in color
            depth_img = Image.fromarray(grayscale).convert('RGB')
        else:
            # Apply the jet colormap; cm.jet returns RGBA floats in [0, 1]
            colored_depth = (cm.jet(normalized_depth) * 255).astype(np.uint8)
            depth_img = Image.fromarray(colored_depth[:, :, :3])
        return depth_img
    def process_images(self, lat=None, lon=None, single_view=None, save_with_heatmap=False):
        """Process all images in the directory, or a single view.

        Args:
            lat: Latitude for caption generation
            lon: Longitude for caption generation
            single_view: If provided, process only the specified view
            save_with_heatmap: If True, also save depth maps as colored heatmaps
        """
        # Collect image paths
        if single_view:
            image_paths = [os.path.join(self.image_dir, f"{single_view}.jpg")]
        else:
            image_paths = [os.path.join(self.image_dir, f) for f in os.listdir(self.image_dir)
                           if f.endswith(('.jpg', '.jpeg', '.png'))]
        if not image_paths:
            print(f"No images found in {self.image_dir}")
            return
        # Generate depth maps
        depth_maps = self.estimate_depth()
        # Process each image
        for i, image_path in enumerate(image_paths):
            image_basename = os.path.basename(image_path)
            view_name = os.path.splitext(image_basename)[0]
            print(f"\nProcessing {view_name} view ({i+1}/{len(image_paths)})...")
            # Start from a vision-based caption; override it with the
            # location-aware caption when coordinates are provided
            analyzer = ImageAnalyzer()
            caption_text = analyzer.analyze_image(image_path)
            if lat is not None and lon is not None:
                view_result = generate_caption(lat, lon, view=view_name, panoramic=False)
                if view_result:
                    caption_text = view_result.get("sound_description", "")
                    print(f"Generated caption: {caption_text}")
                # Skip if coordinates were given but no caption was produced
                if not caption_text:
                    print(f"Failed to generate caption for {image_path}, skipping.")
                    continue
            # Detect objects based on the caption (or a generic noun list)
            if caption_text:
                labels, bboxes = self.detect_objects(image_path, caption_text)
            else:
                print("No caption provided, using predefined nouns for detection...")
                generic_nouns = ["car", "person", "tree", "building", "road", "sign", "window", "door"]
                labels, bboxes = self.detect_objects(image_path, " ".join(generic_nouns))
            # detect_objects returns (None, None) when no nouns were found
            if labels is None or bboxes is None or len(labels) == 0 or len(bboxes) == 0:
                print(f"No objects detected in {image_path}, skipping.")
                continue
            # Find the matching depth map (fall back to index order)
            depth_map_idx = next((idx for idx, data in enumerate(depth_maps)
                                  if os.path.basename(image_path) == os.path.basename(data.get("source_path", ""))),
                                 i % len(depth_maps))
            depth_map = depth_maps[depth_map_idx]["normalization"]
            # Get the depth zone for each detected object
            depth_zones = []
            for bbox in bboxes:
                zone, mean_depth = self.get_depth_zone(bbox, depth_map)
                depth_zones.append((zone, mean_depth))
            # Draw bounding boxes on the original image and save it
            original_img = Image.open(image_path).convert("RGB")
            bbox_img = original_img.copy()
            bbox_img = self.draw_bounding_boxes(bbox_img, labels, bboxes, depth_zones=depth_zones)
            bbox_path = os.path.join(self.dirs["bbox_original"], f"bbox_{image_basename}")
            bbox_img.save(bbox_path)
            print(f"Saved bounding boxes on original image: {bbox_path}")
            # Draw bounding boxes on a grayscale depth visualization
            depth_vis = self.create_depth_map_visualization(depth_map, use_grayscale=True)
            depth_bbox_img = depth_vis.copy()
            depth_bbox_img = self.draw_bounding_boxes(depth_bbox_img, labels, bboxes, depth_zones=depth_zones)
            # Also draw boxes on the saved grayscale depth map
            original_depth_path = depth_maps[depth_map_idx]["path"]
            original_depth_img = Image.open(original_depth_path).convert('RGB')
            original_depth_bbox = original_depth_img.copy()
            original_depth_bbox = self.draw_bounding_boxes(original_depth_bbox, labels, bboxes, depth_zones=depth_zones)
            original_depth_bbox_path = os.path.join(self.dirs["bbox_depth"], f"orig_depth_bbox_{image_basename}")
            original_depth_bbox.save(original_depth_bbox_path)
            print(f"Saved bounding boxes on original depth map: {original_depth_bbox_path}")
            depth_bbox_path = os.path.join(self.dirs["bbox_depth"], f"depth_bbox_{image_basename}")
            depth_bbox_img.save(depth_bbox_path)
            print(f"Saved bounding boxes on depth map: {depth_bbox_path}")
            # Optionally save a colored heatmap version as well
            if save_with_heatmap:
                depth_heatmap = self.create_depth_map_visualization(depth_map, use_grayscale=False)
                depth_heatmap_bbox = depth_heatmap.copy()
                depth_heatmap_bbox = self.draw_bounding_boxes(depth_heatmap_bbox, labels, bboxes, depth_zones=depth_zones)
                heatmap_path = os.path.join(self.dirs["bbox_depth"], f"heatmap_bbox_{image_basename}")
                depth_heatmap_bbox.save(heatmap_path)
                print(f"Saved bounding boxes on depth heatmap: {heatmap_path}")
            # Combined visualization: a side-by-side grid of the original
            # with bboxes and the depth map with bboxes
            combined_img = Image.new('RGB', (original_img.width * 2, original_img.height))
            combined_img.paste(bbox_img, (0, 0))
            combined_img.paste(original_depth_bbox, (original_img.width, 0))
            combined_path = os.path.join(self.dirs["combined"], f"combined_{image_basename}")
            combined_img.save(combined_path)
            print(f"Saved combined visualization: {combined_path}")
        print("\nVisualization process complete!")
        print(f"Results saved in {self.output_dir}")
    def cleanup(self):
        """Release model resources and free GPU memory"""
        if hasattr(self, 'depth_estimator'):
            self.depth_estimator._unload_model()
        if self.dino is not None:
            self.dino = self.dino.to("cpu")
            del self.dino
            self.dino = None
        if self.nlp is not None:
            del self.nlp
            self.nlp = None
        torch.cuda.empty_cache()

def main():
    import argparse
    parser = argparse.ArgumentParser(description="Visualize intermediate steps of the Street Sound Pipeline")
    parser.add_argument("--image_dir", type=str, default=LOGS_DIR, help="Directory containing input images")
    parser.add_argument("--output_dir", type=str, default=None, help="Directory for output visualizations")
    parser.add_argument("--location", type=str, default=None, help='Location in the format "latitude,longitude" (e.g., "40.7128,-74.0060")')
    parser.add_argument("--view", type=str, default=None, choices=["front", "back", "left", "right"], help="Process only the specified view")
    parser.add_argument("--skip_caption", action="store_true", help="Skip caption generation and use a generic noun list")
    parser.add_argument("--save_heatmap", action="store_true", help="Also save depth maps as colored heatmaps with bounding boxes")
    parser.add_argument("--box_width", type=int, default=3, help="Width of bounding box lines")
    args = parser.parse_args()
    # Parse the location if provided
    lat, lon = None, None
    if args.location and not args.skip_caption:
        try:
            lat, lon = map(float, args.location.split(","))
        except ValueError:
            print("Error: Location must be in the format 'latitude,longitude'")
            return
    # Initialize the visualizer
    visualizer = ProcessVisualizer(image_dir=args.image_dir, output_dir=args.output_dir)
    # Override the box width by wrapping draw_bounding_boxes: pre-draw wider
    # rectangles, then let the original method add labels on top
    if args.box_width != 3:
        draw_bounding_boxes_orig = visualizer.draw_bounding_boxes
        box_width = args.box_width  # captured here so *fn_args doesn't shadow `args`

        def draw_bounding_boxes_with_width(*fn_args, **kwargs):
            draw = ImageDraw.Draw(fn_args[0])
            depth_zones = kwargs.get('depth_zones')
            # zone_colors is set on the visualizer by draw_bounding_boxes
            zone_colors = getattr(visualizer, 'zone_colors', {})
            for i, (_, bbox) in enumerate(zip(fn_args[1], fn_args[2])):
                x1, y1, x2, y2 = [int(coord) for coord in bbox]
                if depth_zones is not None and i < len(depth_zones):
                    zone, depth = depth_zones[i]
                    color = zone_colors.get(zone, (0, 0, 255))
                else:
                    color = (255, 0, 0)
                draw.rectangle([x1, y1, x2, y2], outline=color, width=box_width)
            return draw_bounding_boxes_orig(*fn_args, **kwargs)

        visualizer.draw_bounding_boxes = draw_bounding_boxes_with_width
    try:
        # Process the images
        visualizer.process_images(lat=lat, lon=lon, single_view=args.view, save_with_heatmap=args.save_heatmap)
    finally:
        # Clean up resources
        visualizer.cleanup()


if __name__ == "__main__":
    main()
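
# Example invocations (the script filename is an assumption; adjust to match
# the actual file):
#   python visualize_pipeline.py --location "40.7128,-74.0060" --view front
#   python visualize_pipeline.py --skip_caption --save_heatmap --box_width 5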