Spaces:
Paused
Paused
import torch | |
import numpy as np | |
import os | |
import torch.nn.functional as F | |
from PIL import Image | |
import time | |
import random | |
def generate_random_mask(
    batch_size,
    height=256,
    width=256,
    device='cuda',
    min_coverage=0.2,
    max_coverage=0.8,
    num_blobs_range=(1, 3)
):
    """
    Build a batch of random binary blob masks.

    Each mask is the union of a few rotated, stretched ellipses whose
    outlines are warped by a sum of low-frequency sine waves, then blurred
    and re-thresholded so the edges come out smooth. A single dilation or
    erosion pass nudges the mask toward a coverage drawn from
    [min_coverage, max_coverage]; that target is not guaranteed to be hit.

    Args:
        batch_size (int): Number of masks to generate
        height (int): Height of each mask in pixels
        width (int): Width of each mask in pixels
        device (str): Device the tensors are created on ('cuda' or 'cpu')
        min_coverage (float): Lower bound of the target covered fraction (0-1)
        max_coverage (float): Upper bound of the target covered fraction (0-1)
        num_blobs_range (tuple): Inclusive (min, max) number of blobs per mask

    Returns:
        torch.Tensor: Masks of 0.0/1.0 values with shape
            (batch_size, 1, height, width)
    """
    def reflect_blur(image, kernel):
        # Reflect-pad by half the kernel width so the output keeps its size
        half = kernel.shape[-1] // 2
        padded = F.pad(image, (half, half, half, half), mode='reflect')
        return F.conv2d(padded, kernel, padding=0)

    # Pixel coordinate grids shared by every blob
    row_grid = torch.arange(height, device=device).view(
        height, 1).expand(height, width)
    col_grid = torch.arange(width, device=device).view(
        1, width).expand(height, width)

    # Smoothing kernels reshaped into conv2d weights (out=1, in=1, k, k)
    blur_small = get_gaussian_kernel(7, 1.0).to(device).view(1, 1, 7, 7)
    blur_large = get_gaussian_kernel(15, 2.5).to(device).view(1, 1, 15, 15)

    # Blob radius bounds derived from the shorter image side
    shortest_side = min(height, width)
    radius_hi = shortest_side // 3
    radius_lo = shortest_side // 8

    result = torch.zeros((batch_size, 1, height, width), device=device)

    for sample_idx in range(batch_size):
        # NOTE: the order of the np.random draws below is deliberate; keeping
        # it fixed means a seeded run reproduces the same masks.
        blob_count = np.random.randint(
            num_blobs_range[0], num_blobs_range[1] + 1)
        wanted_coverage = np.random.uniform(min_coverage, max_coverage)

        canvas = torch.zeros(1, 1, height, width, device=device)

        for _ in range(blob_count):
            # Sum of a few low-frequency sine products: a smooth field used
            # to bend the ellipse outline into an organic shape
            warp = torch.zeros(height, width, device=device)
            wave_count = np.random.randint(2, 5)
            for wave_idx in range(wave_count):
                freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
                freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
                phase_x = np.random.uniform(0, 2 * np.pi)
                phase_y = np.random.uniform(0, 2 * np.pi)
                # Later waves contribute progressively less
                amplitude = np.random.uniform(0.5, 1.0) * \
                    radius_hi / (wave_idx + 1.5)
                warp += torch.sin(col_grid * freq_x + phase_x) * \
                    torch.sin(row_grid * freq_y + phase_y) * amplitude

            # Ellipse centre (kept in the middle half of the image), size,
            # per-axis stretch and orientation
            center_row = np.random.randint(height // 4, 3 * height // 4)
            center_col = np.random.randint(width // 4, 3 * width // 4)
            radius = np.random.randint(radius_lo, radius_hi)
            stretch_row = np.random.uniform(0.6, 1.4)
            stretch_col = np.random.uniform(0.6, 1.4)
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)

            # Distance of every pixel from the centre in the stretched,
            # rotated frame
            off_row = (row_grid - center_row) * stretch_row
            off_col = (col_grid - center_col) * stretch_col
            rot_row = off_row * cos_a - off_col * sin_a
            rot_col = off_row * sin_a + off_col * cos_a
            dist = torch.sqrt(rot_row**2 + rot_col**2)

            # Pixels whose warped distance falls inside the radius form the
            # raw blob, lifted to (1, 1, H, W) for the convolutions
            blob = ((dist + warp) < radius).float()[None, None]

            # Heavy blur + random cut-off, then a light blur + 0.5 cut-off
            blob = reflect_blur(blob, blur_large)
            cutoff = np.random.uniform(0.3, 0.6)
            blob = (blob > cutoff).float()
            blob = (reflect_blur(blob, blur_small) > 0.5).float()

            # Union with the blobs placed so far
            canvas = torch.maximum(canvas, blob)

        # One corrective step toward the target coverage; an empty mask is
        # left alone
        coverage = canvas.mean().item()
        if coverage > 0:
            if coverage < wanted_coverage * 0.7:
                # Too small: dilate with a 5x5 max filter
                canvas = F.pad(canvas, (2, 2, 2, 2), mode='reflect')
                canvas = F.max_pool2d(
                    canvas, kernel_size=5, stride=1, padding=0)
            elif coverage > wanted_coverage * 1.3:
                # Too large: erode by keeping pixels whose 3x3 mean is > 0.7
                canvas = F.pad(canvas, (1, 1, 1, 1), mode='reflect')
                canvas = F.avg_pool2d(
                    canvas, kernel_size=3, stride=1, padding=0)
                canvas = (canvas > 0.7).float()

        # Final light blur and re-binarisation
        canvas = (reflect_blur(canvas, blur_small) > 0.5).float()

        result[sample_idx] = canvas

    return result
def get_gaussian_kernel(kernel_size=5, sigma=1.0):
    """
    Return a normalized 2D Gaussian kernel.

    The kernel is sampled at integer pixel offsets centred on the kernel,
    so ``sigma`` is a standard deviation measured in pixels. (Sampling on a
    grid that itself scales with ``sigma`` would cancel ``sigma`` out of the
    exponent and make the argument a no-op.)

    Args:
        kernel_size (int): Side length of the square kernel, at least 1
        sigma (float): Standard deviation of the Gaussian in pixels, > 0

    Returns:
        torch.Tensor: float32 tensor of shape (kernel_size, kernel_size)
            whose entries sum to 1

    Raises:
        ValueError: If ``kernel_size`` < 1 or ``sigma`` <= 0
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # Pixel offsets from the kernel centre, e.g. [-2, -1, 0, 1, 2] for size 5
    offsets = torch.arange(kernel_size, dtype=torch.float32) - \
        (kernel_size - 1) / 2
    # 1D Gaussian profile; the 2D kernel is separable, so it is the outer
    # product of the profile with itself
    profile = torch.exp(-offsets**2 / (2 * sigma**2))
    kernel = profile[:, None] * profile[None, :]
    # Normalise so that convolving preserves overall intensity
    return kernel / kernel.sum()
def save_masks_as_images(masks, suffix="", output_dir="output"):
    """
    Write every mask of a batch to disk as an RGB JPG using PIL.

    Files are named ``mask_<index, 3 digits><suffix>.jpg`` inside
    ``output_dir``, which is created when it does not exist yet. Only
    channel 0 of each sample is saved.
    """
    os.makedirs(output_dir, exist_ok=True)

    for index, sample in enumerate(masks):
        # Move to host memory and map mask values onto the 0-255 byte range
        gray = (sample[0].cpu().numpy() * 255).astype(np.uint8)
        # Replicate the single channel three times: white mask on black
        rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
        target = os.path.join(output_dir, f"mask_{index:03d}{suffix}.jpg")
        Image.fromarray(rgb).save(target, quality=95)
def random_dialate_mask(mask, max_percent=0.05):
    """
    Randomly dilate each mask of a batch with a kernel of random size.

    Every sample is box-blurred with an odd, randomly sized kernel (zero
    padding at the border) and re-binarised at a random threshold drawn from
    [0.2, 0.8). Low thresholds grow the mask; thresholds above roughly 0.5
    may instead trim its border. A sample whose drawn kernel size is 1 is
    passed through unchanged. Channels are filtered independently.

    Args:
        mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width]
        max_percent (float): Maximum kernel size as a fraction of the mask width

    Returns:
        torch.Tensor: Dilated mask with the same shape and dtype as the input
    """
    batch, channels = mask.shape[0], mask.shape[1]
    if batch == 0:
        # Nothing to process; torch.cat on an empty list would raise
        return mask.clone()

    # randint's upper bound is exclusive, so 3 is the smallest bound that
    # still leaves a real choice ({1, 2})
    max_size = max(int(mask.shape[-1] * max_percent), 3)

    out_chunks = []
    for chunk in mask.split(1, dim=0):
        kernel_size = np.random.randint(1, max_size)

        # A 1x1 kernel would be a no-op: keep the sample as is
        if kernel_size < 2:
            out_chunks.append(chunk)
            continue

        # Odd sizes keep the kernel centred on each pixel
        if kernel_size % 2 == 0:
            kernel_size += 1

        # Normalized box kernel, one copy per channel; together with
        # groups=channels each channel is averaged on its own
        kernel = torch.ones(
            (channels, 1, kernel_size, kernel_size), device=mask.device
        ) / (kernel_size * kernel_size)

        # Compute in float32 so masks of any dtype are accepted; zero
        # padding of half the kernel width keeps the spatial size
        averaged = F.conv2d(
            chunk.float(), kernel, padding=kernel_size // 2, groups=channels)

        # Random threshold for a varied dilation effect
        threshold = np.random.uniform(0.2, 0.8)
        out_chunks.append((averaged > threshold).to(mask.dtype))

    return torch.cat(out_chunks, dim=0)
if __name__ == "__main__":
    # Demo / micro-benchmark: generate a batch of masks five times, report
    # the time of each run, then save the masks of the last run to disk.
    batch_size = 20
    height = 256
    width = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    on_gpu = device == 'cuda'

    print(f"Generating {batch_size} random blob masks on {device}...")

    for _ in range(5):
        # CUDA kernels run asynchronously: synchronise before reading the
        # clock so the measurement covers the actual GPU work
        if on_gpu:
            torch.cuda.synchronize()
        start = time.perf_counter()

        masks = generate_random_mask(
            batch_size=batch_size,
            height=height,
            width=width,
            device=device,
            min_coverage=0.2,
            max_coverage=0.8,
            num_blobs_range=(1, 3)
        )
        dilated_masks = random_dialate_mask(masks)

        if on_gpu:
            torch.cuda.synchronize()
        end = time.perf_counter()

        # Printing happens after the clock is stopped so console I/O is not
        # counted in the measurement
        print(f"Generated {batch_size} masks with shape: {masks.shape}")
        print(f"Time taken: {(end - start)*1000:.2f} ms")

    print("Saving masks to 'output' directory...")
    save_masks_as_images(masks)
    save_masks_as_images(dilated_masks, suffix="_dilated")

    print("Done!")