# VisionScout / image_processor.py
# DawnC's picture
# Upload 2 files
# 1487b33 verified
# raw
# history blame
# 13 kB
import os
import numpy as np
import torch
import cv2
from PIL import Image
import tempfile
import uuid
from typing import Dict, List, Any, Optional, Tuple
from detection_model import DetectionModel
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper
from evaluation_metrics import EvaluationMetrics
class ImageProcessor:
    """
    Class for handling image processing and object detection operations.
    Separates processing logic from UI components.

    Detection models are cached per model name in ``model_instances`` so a
    model is only constructed the first time it is requested.
    """

    def __init__(self):
        """Initialize the image processor with required components"""
        self.color_mapper = ColorMapper()
        # Cache of DetectionModel objects keyed by model name.
        self.model_instances = {}

    def get_model_instance(self, model_name: str, confidence: float = 0.25, iou: float = 0.35) -> "DetectionModel":
        """
        Get or create a model instance based on model name

        Args:
            model_name: Name of the model to use
            confidence: Confidence threshold for detection
            iou: IoU threshold for non-maximum suppression

        Returns:
            DetectionModel instance
        """
        model = self.model_instances.get(model_name)
        if model is None:
            print(f"Creating new model instance for {model_name}")
            model = DetectionModel(
                model_name=model_name,
                confidence=confidence,
                iou=iou,
            )
            self.model_instances[model_name] = model
        else:
            print(f"Using existing model instance for {model_name}")
            # NOTE: only the confidence threshold is refreshed on reuse; `iou`
            # takes effect solely when the instance is first created.
            model.confidence = confidence
        return model

    @staticmethod
    def _to_pil(image):
        """Convert the input (numpy array or PIL Image) into a PIL image that can be saved as JPEG."""
        if isinstance(image, np.ndarray):
            # Only 3-channel arrays are channel-swapped (treated as BGR).
            # Checking ndim first keeps 2-D grayscale arrays from raising IndexError.
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
        else:
            pil_image = image

        # JPEG cannot store alpha/palette modes (e.g. RGBA, LA, P); flatten to RGB
        # so the temporary .jpg can always be written.
        if getattr(pil_image, "mode", "RGB") not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
        return pil_image

    @staticmethod
    def _compute_filtered_stats(classes, confs, names, filter_classes, spatial_metrics) -> Dict:
        """Build the statistics dictionary restricted to the class ids in ``filter_classes``.

        Args:
            classes: 1-D integer array of detected class ids
            confs: 1-D array of detection confidences, aligned with ``classes``
            names: Mapping from class id to class name
            filter_classes: Class ids to keep
            spatial_metrics: Spatial metrics carried over unchanged

        Returns:
            Dict with total_objects, class_statistics, average_confidence
            and spatial_metrics
        """
        mask = np.isin(classes, list(filter_classes))
        kept_classes = classes[mask]
        kept_confs = confs[mask]

        class_statistics = {}
        for cls_id, conf in zip(kept_classes, kept_confs):
            entry = class_statistics.setdefault(
                names[int(cls_id)],
                {"count": 0, "average_confidence": 0.0},
            )
            entry["count"] += 1
            # Running mean over all boxes of this class, stored as a plain
            # Python float (not a NumPy scalar).
            entry["average_confidence"] += (float(conf) - entry["average_confidence"]) / entry["count"]

        return {
            "total_objects": int(np.sum(mask)),
            "class_statistics": class_statistics,
            "average_confidence": float(np.mean(kept_confs)) if kept_confs.size else 0,
            "spatial_metrics": spatial_metrics,
        }

    def process_image(self, image, model_name: str, confidence_threshold: float, filter_classes: Optional[List[int]] = None) -> Tuple[Any, str, Dict]:
        """
        Process an image for object detection

        Args:
            image: Input image (numpy array or PIL Image)
            model_name: Name of the model to use
            confidence_threshold: Confidence threshold for detection
            filter_classes: Optional list of classes to filter results

        Returns:
            Tuple of (result_image, result_text, stats_data). On failure the
            image is None, the text carries the error message and the stats
            dictionary is empty.
        """
        # Reject missing input before a model is created or looked up.
        if image is None:
            return None, "No image provided. Please upload an image.", {}

        # Get model instance
        model_instance = self.get_model_instance(model_name, confidence_threshold)

        temp_path = None
        try:
            pil_image = self._to_pil(image)

            # The detector is given a file path, so the image is written to a
            # uniquely named file in the system temp directory.
            temp_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid.uuid4().hex}.jpg")
            pil_image.save(temp_path)

            # Object detection
            result = model_instance.detect(temp_path)
            if result is None:
                return None, "Detection failed. Please try again with a different image.", {}

            # Calculate stats, then attach the spatial metrics
            stats = EvaluationMetrics.calculate_basic_stats(result)
            stats["spatial_metrics"] = EvaluationMetrics.calculate_distance_metrics(result)

            # Apply filter if specified.
            # NOTE: the filter only narrows the returned statistics; the
            # visualization and summary text below use the unfiltered result.
            if filter_classes:
                stats = self._compute_filtered_stats(
                    result.boxes.cls.cpu().numpy().astype(int),
                    result.boxes.conf.cpu().numpy(),
                    result.names,
                    filter_classes,
                    stats["spatial_metrics"],
                )

            viz_data = EvaluationMetrics.generate_visualization_data(
                result,
                self.color_mapper.get_all_colors()
            )
            result_image = VisualizationHelper.visualize_detection(
                temp_path, result, color_mapper=self.color_mapper, figsize=(12, 12), return_pil=True
            )
            result_text = EvaluationMetrics.format_detection_summary(viz_data)
            return result_image, result_text, stats
        except Exception as e:
            # Boundary handler: report the failure as text instead of raising.
            error_message = f"Error Occurs: {str(e)}"
            import traceback
            traceback.print_exc()
            print(error_message)
            return None, error_message, {}
        finally:
            # Best-effort cleanup of the temporary image file.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except Exception as e:
                    print(f"Cannot delete temp files {temp_path}: {str(e)}")

    def format_result_text(self, stats: Dict) -> str:
        """
        Format detection statistics into readable text with improved spacing

        Args:
            stats: Dictionary containing detection statistics

        Returns:
            Formatted text summary
        """
        if not stats or "total_objects" not in stats:
            return "No objects detected."

        # Header lines, without unnecessary blank lines
        lines = [
            f"Detected {stats['total_objects']} objects.",
            f"Average confidence: {stats.get('average_confidence', 0):.2f}",
            "Objects by class:"
        ]

        class_statistics = stats.get("class_statistics")
        if class_statistics:
            # Sort classes by count, most frequent first
            ranked = sorted(
                class_statistics.items(),
                key=lambda item: item[1]["count"],
                reverse=True
            )
            for cls_name, cls_stats in ranked:
                count = cls_stats["count"]
                conf = cls_stats.get("average_confidence", 0)
                item_text = "item" if count == 1 else "items"
                lines.append(f"• {cls_name}: {count} {item_text} (avg conf: {conf:.2f})")
        else:
            lines.append("No class information available.")

        # Add spatial information
        spatial = stats.get("spatial_metrics") or {}
        if "spatial_distribution" in spatial:
            lines.append("Object Distribution:")
            dist = spatial["spatial_distribution"]
            x_mean = dist.get("x_mean", 0)
            y_mean = dist.get("y_mean", 0)

            # Describe the rough location of the objects; the 0.33 / 0.67
            # cut-offs presumably expect means normalized to [0, 1] — TODO confirm.
            if x_mean < 0.33:
                h_pos = "on the left side"
            elif x_mean < 0.67:
                h_pos = "in the center"
            else:
                h_pos = "on the right side"

            if y_mean < 0.33:
                v_pos = "in the upper part"
            elif y_mean < 0.67:
                v_pos = "in the middle"
            else:
                v_pos = "in the lower part"

            lines.append(f"• Most objects appear {h_pos} {v_pos} of the image")

        return "\n".join(lines)

    def format_json_for_display(self, stats: Dict) -> Dict:
        """
        Format statistics JSON for better display

        Args:
            stats: Raw statistics dictionary (None or empty is treated as no data)

        Returns:
            Formatted statistics structure for display
        """
        stats = stats or {}

        # Create a cleaner copy of the stats for display, starting with the summary
        display_stats = {
            "summary": {
                "total_objects": stats.get("total_objects", 0),
                "average_confidence": round(stats.get("average_confidence", 0), 3)
            }
        }

        # Add class statistics in a more organized way
        class_statistics = stats.get("class_statistics")
        if class_statistics:
            # Sort classes by count (descending)
            ranked = sorted(
                class_statistics.items(),
                key=lambda item: item[1].get("count", 0),
                reverse=True
            )
            display_stats["detected_objects"] = {
                cls_name: {
                    "count": cls_data.get("count", 0),
                    "average_confidence": round(cls_data.get("average_confidence", 0), 3)
                }
                for cls_name, cls_data in ranked
            }

        # Simplify spatial metrics
        spatial = stats.get("spatial_metrics") or {}

        # Simplify spatial distribution
        if "spatial_distribution" in spatial:
            dist = spatial["spatial_distribution"]
            display_stats["spatial"] = {
                "distribution": {
                    "x_mean": round(dist.get("x_mean", 0), 3),
                    "y_mean": round(dist.get("y_mean", 0), 3),
                    "x_std": round(dist.get("x_std", 0), 3),
                    "y_std": round(dist.get("y_std", 0), 3)
                }
            }

        # Add simplified size information
        if "size_distribution" in spatial:
            size = spatial["size_distribution"]
            # setdefault: the "spatial" section may not exist yet when no
            # spatial distribution was reported.
            display_stats.setdefault("spatial", {})["size"] = {
                "mean_area": round(size.get("mean_area", 0), 3),
                "min_area": round(size.get("min_area", 0), 3),
                "max_area": round(size.get("max_area", 0), 3)
            }

        return display_stats

    def prepare_visualization_data(self, stats: Dict, available_classes: Dict[int, str]) -> Dict:
        """
        Prepare data for visualization based on detection statistics

        Args:
            stats: Detection statistics
            available_classes: Dictionary of available class IDs and names

        Returns:
            Visualization data dictionary, or {"error": ...} when there are
            no class statistics to show
        """
        if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
            return {"error": "No detection data available"}

        # Reverse lookup (name -> id) built once; setdefault keeps the first
        # id encountered when several ids share a name.
        id_by_name = {}
        for class_id, name in (available_classes or {}).items():
            id_by_name.setdefault(name, class_id)

        # Prepare visualization data
        viz_data = {
            "total_objects": stats.get("total_objects", 0),
            "average_confidence": stats.get("average_confidence", 0),
            "class_data": []
        }

        # Class data
        for cls_name, cls_stats in stats["class_statistics"].items():
            # -1 marks a class name that has no known id
            class_id = id_by_name.get(cls_name, -1)
            viz_data["class_data"].append({
                "name": cls_name,
                "class_id": class_id,
                "count": cls_stats.get("count", 0),
                "average_confidence": cls_stats.get("average_confidence", 0),
                # Colors are looked up by id when known, otherwise by name
                "color": self.color_mapper.get_color(class_id if class_id >= 0 else cls_name)
            })

        # Descending order
        viz_data["class_data"].sort(key=lambda entry: entry["count"], reverse=True)
        return viz_data