import streamlit as st

# THIS MUST BE THE FIRST STREAMLIT COMMAND
st.set_page_config(
    page_title="Pet Segmentation with SegFormer",
    page_icon="🐶",
    layout="wide",
    initial_sidebar_state="expanded"
)

import tensorflow as tf
from tensorflow.keras import backend
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import os
import io
import gdown
from transformers import TFSegformerForSemanticSegmentation

try:
    # Limit GPU memory growth
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        st.sidebar.success(f"GPU available: {len(gpus)} device(s)")
    else:
        st.sidebar.warning("No GPU detected, using CPU")
except Exception as e:
    st.sidebar.error(f"GPU config error: {e}")

# Constants for image preprocessing
IMAGE_SIZE = 512   # model input resolution
OUTPUT_SIZE = 128  # resolution used for IoU comparison (matches the model's 1/4-scale logits)
# ImageNet channel-wise statistics, used to normalize inputs scaled to [0, 1]
MEAN = tf.constant([0.485, 0.456, 0.406])
STD = tf.constant([0.229, 0.224, 0.225])

# Class labels
ID2LABEL = {0: "background", 1: "border", 2: "foreground/pet"}
NUM_CLASSES = len(ID2LABEL)

def download_model_from_drive():
    """
    Download the fine-tuned weights from Google Drive, caching them locally.

    Returns:
        Path to the weights file, or None if the download failed
    """
    # Create a models directory
    os.makedirs("models", exist_ok=True)

    model_path = "models/tf_model.h5"
    if not os.path.exists(model_path):
        # File ID extracted from the Google Drive sharing URL
        file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
        url = f"https://drive.google.com/uc?id={file_id}"
        try:
            gdown.download(url, model_path, quiet=False)
            st.success("Model downloaded successfully from Google Drive.")
        except Exception as e:
            st.error(f"Failed to download model: {e}")
            return None
    else:
        st.info("Model already exists locally.")

    return model_path

def load_model():
    """
    Load the SegFormer model

    Returns:
        Loaded model
    """
    try:
        # First create a base model with the correct architecture
        base_model = TFSegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=NUM_CLASSES,
            id2label=ID2LABEL,
            label2id={label: id for id, label in ID2LABEL.items()},
            ignore_mismatched_sizes=True
        )

        # Download the trained weights
        model_path = download_model_from_drive()

        if model_path is not None and os.path.exists(model_path):
            st.info(f"Loading weights from {model_path}...")
            try:
                # Try to load the fine-tuned weights
                base_model.load_weights(model_path)
                st.success("Model weights loaded successfully!")
                return base_model
            except Exception as e:
                st.error(f"Error loading weights: {e}")
                st.info("Using base pretrained model instead")
                return base_model
        else:
            st.warning("Using base pretrained model since download failed")
            return base_model
    except Exception as e:
        st.error(f"Error in load_model: {e}")
        st.warning("Using default pretrained model")
        # Fall back to the pretrained model as a last resort
        return TFSegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=NUM_CLASSES,
            id2label=ID2LABEL,
            label2id={label: id for id, label in ID2LABEL.items()},
            ignore_mismatched_sizes=True
        )

def normalize_image(input_image):
    """
    Normalize the input image

    Args:
        input_image: Image to normalize

    Returns:
        Normalized image
    """
    input_image = tf.image.convert_image_dtype(input_image, tf.float32)
    input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
    return input_image

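
# A worked example of the normalization above (values approximate): a uint8
# pixel of 128 becomes 128/255 ≈ 0.502 after dtype conversion, and on the red
# channel then maps to (0.502 - 0.485) / 0.229 ≈ 0.074.
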
def preprocess_image(image, from_dataset=False):
    """
    Preprocess image for model input

    Args:
        image: PIL Image to preprocess
        from_dataset: Whether the image is from the original dataset
                      (kept so dataset-specific handling can be added later)

    Returns:
        Preprocessed image tensor, original image
    """
    # Convert PIL Image to numpy array
    img_array = np.array(image.convert('RGB'))

    # Store original image for display
    original_img = img_array.copy()

    # Convert to float32 in [0, 1] *before* resizing: tf.image.resize returns
    # float32 without rescaling, and convert_image_dtype leaves float inputs
    # unscaled, so resizing the raw uint8 array first would feed 0-255 values
    # into the ImageNet mean/std normalization
    img_float = tf.image.convert_image_dtype(img_array, tf.float32)

    # Resize to target size
    img_resized = tf.image.resize(
        img_float,
        (IMAGE_SIZE, IMAGE_SIZE),
        method='bilinear',
        preserve_aspect_ratio=False,
        antialias=True
    )

    # Dataset and uploaded images currently share the same normalization
    img_normalized = normalize_image(img_resized)

    # Transpose from HWC to CHW (channels first, as SegFormer expects)
    img_transposed = tf.transpose(img_normalized, (2, 0, 1))

    # Add batch dimension
    img_batch = tf.expand_dims(img_transposed, axis=0)

    return img_batch, original_img

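
# Shape sanity check (hypothetical 300x400 RGB upload): preprocess_image
# returns a (1, 3, 512, 512) float32 tensor for the model, plus the original
# (300, 400, 3) uint8 array for display.
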
def create_mask(pred_mask):
    """
    Convert model prediction to displayable mask

    Args:
        pred_mask: Prediction logits from the model,
                   shaped (batch, num_classes, height, width)

    Returns:
        Processed mask (2D uint8 array)
    """
    # Take argmax along the class dimension (axis=1 for channels-first logits)
    pred_mask = tf.math.argmax(pred_mask, axis=1)

    # Remove batch dimension
    pred_mask = tf.squeeze(pred_mask)

    # SegFormer predicts at 1/4 of the input resolution (128 for a 512 input),
    # so upsample with nearest-neighbor to keep class ids intact
    if pred_mask.shape[0] != IMAGE_SIZE or pred_mask.shape[1] != IMAGE_SIZE:
        pred_mask = tf.image.resize(
            tf.expand_dims(pred_mask, axis=-1),
            (IMAGE_SIZE, IMAGE_SIZE),
            method='nearest'
        )
        pred_mask = tf.squeeze(pred_mask)

    # Cast from int64 to uint8 so OpenCV calls (e.g. cv2.resize) accept the mask
    return pred_mask.numpy().astype(np.uint8)

def colorize_mask(mask):
    """
    Apply colors to segmentation mask

    Args:
        mask: Segmentation mask (2D array)

    Returns:
        Colorized mask (3D RGB array)
    """
    # Ensure the mask is 2D
    if len(mask.shape) > 2:
        mask = np.squeeze(mask)

    # Define colors for each class (RGB)
    colors = [
        [0, 0, 0],      # Background (black)
        [255, 0, 0],    # Border (red)
        [0, 0, 255]     # Foreground/pet (blue)
    ]

    # Create RGB mask
    rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        class_mask = (mask == i).astype(np.uint8)
        for c in range(3):
            rgb_mask[:, :, c] += class_mask * color[c]

    return rgb_mask

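
# Example (hypothetical 2x2 mask): for m = np.array([[0, 1], [2, 2]]),
# colorize_mask(m)[0, 1] is [255, 0, 0] (border, red) and
# colorize_mask(m)[1, 1] is [0, 0, 255] (pet, blue).
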
def calculate_iou(y_true, y_pred, class_idx=None):
    """
    Calculate IoU (Intersection over Union) for segmentation masks

    Args:
        y_true: Ground truth segmentation mask
        y_pred: Predicted segmentation mask
        class_idx: Index of the class to calculate IoU for (None for mean IoU)

    Returns:
        IoU score
    """
    if class_idx is not None:
        # Binary IoU for a specific class
        y_true_class = (y_true == class_idx).astype(np.float32)
        y_pred_class = (y_pred == class_idx).astype(np.float32)

        intersection = np.sum(y_true_class * y_pred_class)
        union = np.sum(y_true_class) + np.sum(y_pred_class) - intersection

        # Epsilon avoids division by zero when a class is absent from both masks
        iou = intersection / (union + 1e-6)
    else:
        # Mean IoU across all classes
        class_ious = []
        for idx in range(NUM_CLASSES):
            class_iou = calculate_iou(y_true, y_pred, idx)
            class_ious.append(class_iou)
        iou = np.mean(class_ious)

    return iou

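
# Worked example (hypothetical 2x2 masks):
#   y_true = np.array([[0, 1], [2, 2]])
#   y_pred = np.array([[0, 1], [2, 0]])
# Class 2 has intersection 1 and union 2 + 1 - 1 = 2, so its IoU ≈ 0.5;
# classes 0 and 1 give ≈ 0.5 and ≈ 1.0, so calculate_iou returns a mean ≈ 0.67.
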
def create_overlay(image, mask, alpha=0.5):
    """
    Create an overlay of mask on original image

    Args:
        image: Original image
        mask: Segmentation mask
        alpha: Transparency level (0-1)

    Returns:
        Overlay image
    """
    # Ensure mask shape matches image
    if image.shape[:2] != mask.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

    # Create blend
    overlay = cv2.addWeighted(
        image,
        1,
        mask.astype(np.uint8),
        alpha,
        0
    )

    return overlay

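
# Note: with a source weight of 1, the blend above computes
# image + alpha * mask (saturating at 255), which brightens masked regions.
# A conventional alpha composite, if preferred, would instead be:
#   overlay = cv2.addWeighted(image, 1 - alpha, mask.astype(np.uint8), alpha, 0)
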
def display_results_side_by_side(original_image, ground_truth_mask=None, predicted_mask=None):
    """
    Display results in a side-by-side format similar to colab_code.py

    Args:
        original_image: Original input image
        ground_truth_mask: Optional ground truth segmentation mask
        predicted_mask: Predicted segmentation mask
    """
    # Determine how many images to display
    cols = 1 + (ground_truth_mask is not None) + (predicted_mask is not None)

    # Create a figure with multiple columns
    st.write("### Segmentation Results Comparison")
    col_list = st.columns(cols)

    # Display original image
    with col_list[0]:
        st.markdown("**Original Image**")
        st.image(original_image, use_column_width=True)

    # Display ground truth if available
    if ground_truth_mask is not None:
        with col_list[1]:
            st.markdown("**Ground Truth Mask**")
            # Colorize ground truth if needed
            if len(ground_truth_mask.shape) == 2:
                gt_display = colorize_mask(ground_truth_mask)
            else:
                gt_display = ground_truth_mask
            st.image(gt_display, use_column_width=True)

    # Display prediction
    if predicted_mask is not None:
        with col_list[2 if ground_truth_mask is not None else 1]:
            st.markdown("**Predicted Mask**")
            # Colorize prediction if needed
            if len(predicted_mask.shape) == 2:
                pred_display = colorize_mask(predicted_mask)
            else:
                pred_display = predicted_mask
            st.image(pred_display, use_column_width=True)

def process_uploaded_mask(mask_array, from_dataset=False):
    """
    Process an uploaded mask to ensure it has the correct format

    Args:
        mask_array: Numpy array of the mask
        from_dataset: Whether the mask is from the original dataset

    Returns:
        Processed mask with values 0, 1, 2
    """
    # Check for RGBA format and convert to RGB if needed
    if len(mask_array.shape) == 3 and mask_array.shape[2] == 4:
        # Discard the alpha channel
        mask_array = mask_array[:, :, :3]

    # Convert RGB to grayscale if needed
    if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
        mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)

    if from_dataset:
        # For masks exported from the Colab training code: the dataset trimaps
        # use values 1, 2, 3, while this app expects 0, 1, 2
        processed_mask = np.zeros_like(mask_array)
        processed_mask[mask_array == 1] = 2  # Foreground/pet (1 → 2)
        processed_mask[mask_array == 2] = 1  # Border (2 → 1)
        processed_mask[mask_array == 3] = 0  # Background (3 → 0)
        return processed_mask
    else:
        # For non-dataset masks, assume the class values are already 0, 1, 2
        return mask_array

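
# Example (hypothetical): a dataset trimap row [1, 2, 3] is remapped to
# [2, 1, 0], i.e. pet, border, background in this app's label scheme.
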
def main():
    st.title("🐶 Pet Segmentation with SegFormer")
    st.markdown("""
    This app demonstrates semantic segmentation of pet images using a SegFormer model.

    The model segments images into three classes:
    - **Background**: Areas around the pet
    - **Border**: The boundary/outline around the pet
    - **Foreground**: The pet itself
    """)

    # Sidebar
    st.sidebar.header("Model Information")
    st.sidebar.markdown("""
    **SegFormer** is a state-of-the-art semantic segmentation model based on transformers.

    Key features:
    - Hierarchical transformer encoder
    - Lightweight MLP decoder
    - Efficient mix of local and global attention

    This implementation uses the MIT-B0 variant fine-tuned on the Oxford-IIIT Pet dataset.
    """)

    # Advanced settings in sidebar
    st.sidebar.header("Settings")

    # Overlay opacity
    overlay_opacity = st.sidebar.slider(
        "Overlay Opacity",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1
    )

    # Flag images (and masks) that come from the original dataset so they get
    # the dataset-specific label remapping
    dataset_image = st.sidebar.checkbox("Image is from the Oxford-IIIT Pet dataset")

    # Load model
    with st.spinner("Loading SegFormer model..."):
        model = load_model()
        if model is None:
            st.error("Failed to load model. Using default pretrained model instead.")
        else:
            st.sidebar.success("Model loaded successfully!")

    # Image upload section
    st.header("Upload an Image")
    uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
    uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])

    # Process uploaded image
    if uploaded_image is not None:
        try:
            # Read the image
            image_bytes = uploaded_image.read()
            image = Image.open(io.BytesIO(image_bytes))

            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Original Image")
                st.image(image, caption="Uploaded Image", use_column_width=True)

            # Preprocess and predict
            with st.spinner("Generating segmentation mask..."):
                try:
                    # Preprocess the image
                    img_tensor, original_img = preprocess_image(image, from_dataset=dataset_image)

                    # Print shape to debug
                    st.write(f"DEBUG - Input tensor shape: {img_tensor.shape}")

                    # Make prediction with error handling
                    try:
                        outputs = model(pixel_values=img_tensor, training=False)
                        logits = outputs.logits

                        # Create visualization mask
                        mask = create_mask(logits)

                        # Colorize the mask
                        colorized_mask = colorize_mask(mask)

                        # Create overlay
                        overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
                    except Exception as inference_error:
                        st.error(f"Inference error: {inference_error}")
                        st.write("Trying alternative approach...")

                        # Alternative: resize to exactly 512x512 with crop_or_pad
                        img_resized = tf.image.resize_with_crop_or_pad(
                            original_img, IMAGE_SIZE, IMAGE_SIZE
                        )
                        img_normalized = normalize_image(img_resized)
                        img_transposed = tf.transpose(img_normalized, (2, 0, 1))
                        img_tensor = tf.expand_dims(img_transposed, axis=0)

                        outputs = model(pixel_values=img_tensor, training=False)
                        logits = outputs.logits
                        mask = create_mask(logits)
                        colorized_mask = colorize_mask(mask)
                        overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
                except Exception as e:
                    st.error(f"Failed to process image: {e}")
                    st.stop()

            # Display results
            with col2:
                st.subheader("Segmentation Result")
                st.image(overlay, caption="Segmentation Overlay", use_column_width=True)

            # Display segmentation details
            st.header("Segmentation Details")
            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("Background")
                st.markdown("Areas surrounding the pet")
                mask_bg = np.where(mask == 0, 255, 0).astype(np.uint8)
                st.image(mask_bg, caption="Background", use_column_width=True)

            with col2:
                st.subheader("Border")
                st.markdown("Boundary around the pet")
                mask_border = np.where(mask == 1, 255, 0).astype(np.uint8)
                st.image(mask_border, caption="Border", use_column_width=True)

            with col3:
                st.subheader("Foreground (Pet)")
                st.markdown("The pet itself")
                mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
                st.image(mask_fg, caption="Foreground", use_column_width=True)

            # Calculate IoU if ground truth is uploaded
            if uploaded_mask is not None:
                try:
                    # Reset the file pointer to the beginning
                    uploaded_mask.seek(0)

                    # Read the mask file
                    mask_data = uploaded_mask.read()
                    mask_io = io.BytesIO(mask_data)

                    # Load the raw mask
                    raw_mask = np.array(Image.open(mask_io))

                    # Show debug info
                    st.write(f"Debug - Raw mask shape: {raw_mask.shape}")
                    st.write(f"Debug - Raw mask unique values: {np.unique(raw_mask)}")

                    # Process the mask based on source
                    processed_gt_mask = process_uploaded_mask(raw_mask, from_dataset=dataset_image)

                    # Resize for IoU calculation
                    gt_mask_resized = cv2.resize(processed_gt_mask, (OUTPUT_SIZE, OUTPUT_SIZE),
                                                 interpolation=cv2.INTER_NEAREST)

                    # Resize prediction for comparison
                    pred_mask_resized = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE),
                                                   interpolation=cv2.INTER_NEAREST)

                    # Show processed values
                    st.write(f"Debug - Processed GT mask unique values: {np.unique(gt_mask_resized)}")
                    st.write(f"Debug - Prediction mask unique values: {np.unique(pred_mask_resized)}")

                    # Calculate and display IoU
                    iou_score = calculate_iou(gt_mask_resized, pred_mask_resized)
                    st.success(f"Mean IoU: {iou_score:.4f}")

                    # Display per-class IoUs
                    st.markdown("### IoU by Class")
                    col1, col2, col3 = st.columns(3)

                    with col1:
                        bg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 0)
                        st.metric("Background IoU", f"{bg_iou:.4f}")

                    with col2:
                        border_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 1)
                        st.metric("Border IoU", f"{border_iou:.4f}")

                    with col3:
                        fg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 2)
                        st.metric("Foreground IoU", f"{fg_iou:.4f}")

                    # For display, create a colorized version of the ground truth
                    gt_mask_for_display = colorize_mask(processed_gt_mask)

                    # Side-by-side display
                    display_results_side_by_side(
                        original_img,
                        ground_truth_mask=gt_mask_for_display,
                        predicted_mask=colorized_mask
                    )
                except Exception as e:
                    st.error(f"Error processing ground truth mask: {e}")
                    st.write("Please ensure the mask is valid and has the correct format.")
                    import traceback
                    st.code(traceback.format_exc())  # Show detailed error trace

                    # Even with an error, try to display results without the ground truth
                    display_results_side_by_side(
                        original_img,
                        ground_truth_mask=None,
                        predicted_mask=colorized_mask
                    )
            else:
                # No ground truth, just display original and prediction
                display_results_side_by_side(
                    original_img,
                    ground_truth_mask=None,
                    predicted_mask=colorized_mask
                )

            # Download buttons
            col1, col2 = st.columns(2)

            with col1:
                # Convert mask to PNG for download
                mask_colored = Image.fromarray(colorized_mask)
                mask_bytes = io.BytesIO()
                mask_colored.save(mask_bytes, format='PNG')
                mask_bytes = mask_bytes.getvalue()

                st.download_button(
                    label="Download Segmentation Mask",
                    data=mask_bytes,
                    file_name="pet_segmentation_mask.png",
                    mime="image/png"
                )

            with col2:
                # Convert overlay to PNG for download
                overlay_img = Image.fromarray(overlay)
                overlay_bytes = io.BytesIO()
                overlay_img.save(overlay_bytes, format='PNG')
                overlay_bytes = overlay_bytes.getvalue()

                st.download_button(
                    label="Download Overlay Image",
                    data=overlay_bytes,
                    file_name="pet_segmentation_overlay.png",
                    mime="image/png"
                )
        except Exception as e:
            st.error(f"Error processing image: {e}")

    # Footer with additional information
    st.markdown("---")
    st.markdown("### About the Model")
    st.markdown("""
    This segmentation model is based on the SegFormer architecture and was fine-tuned on the Oxford-IIIT Pet dataset.

    **Key Performance Metrics:**
    - Mean IoU (Intersection over Union): Measures overlap between predictions and ground truth
    - Dice Coefficient: Similar to the F1-score, balances precision and recall

    The model segments pet images into three semantic classes (background, border, and pet/foreground),
    making it useful for applications like pet image editing, background removal, and object detection.
    """)

if __name__ == "__main__":
    main()