# ramimu's picture
# Upload 586 files
# 1c72248 verified
import torch
import numpy as np
import os
import torch.nn.functional as F
from PIL import Image
import time
import random
def generate_random_mask(
    batch_size,
    height=256,
    width=256,
    device='cuda',
    min_coverage=0.2,
    max_coverage=0.8,
    num_blobs_range=(1, 3)
):
    """
    Generate random blob masks for a batch of images.

    Each mask is the union of a few randomly placed, randomly rotated and
    stretched ellipses whose outline is warped by a sum of low-frequency
    sine waves and then smoothed with two Gaussian blur + threshold passes.

    Coverage control is best-effort: after the blobs are merged, the mask is
    dilated once (5x5 max-pool) if it covers less than 70% of the sampled
    target, or eroded once (3x3 mean > 0.7) if it covers more than 130% of
    it. The final coverage is therefore not guaranteed to fall inside
    ``[min_coverage, max_coverage]``.

    Args:
        batch_size (int): Number of masks to generate
        height (int): Height of the mask (at least 8)
        width (int): Width of the mask (at least 8)
        device (str): Device to run the computation on ('cuda' or 'cpu')
        min_coverage (float): Lower bound of the sampled target coverage,
            as a fraction of the image area (0-1)
        max_coverage (float): Upper bound of the sampled target coverage,
            as a fraction of the image area (0-1)
        num_blobs_range (tuple): Inclusive range of number of blobs (min, max)

    Returns:
        torch.Tensor: Float masks containing only 0.0 and 1.0, with shape
        (batch_size, 1, height, width), on ``device``.

    Raises:
        ValueError: If ``height`` or ``width`` is too small for the reflect
            padding used by the large smoothing kernel.
    """
    small_size, large_size = 7, 15

    # Reflect padding needs the pad width to be smaller than the padded
    # dimension; fail early with a clear message instead of deep inside F.pad.
    min_side = min(height, width)
    if min_side <= large_size // 2:
        raise ValueError(
            f"height and width must both be greater than {large_size // 2}, "
            f"got height={height}, width={width}"
        )

    masks = torch.zeros((batch_size, 1, height, width), device=device)

    # Pixel coordinate grids, shared by every blob
    y_indices = torch.arange(
        height, device=device, dtype=torch.float32
    ).view(height, 1).expand(height, width)
    x_indices = torch.arange(
        width, device=device, dtype=torch.float32
    ).view(1, width).expand(height, width)

    # Gaussian kernels for smoothing, shaped as conv2d weights
    small_kernel = get_gaussian_kernel(small_size, 1.0).to(device)
    small_kernel = small_kernel.view(1, 1, small_size, small_size)
    large_kernel = get_gaussian_kernel(large_size, 2.5).to(device)
    large_kernel = large_kernel.view(1, 1, large_size, large_size)

    max_radius = min_side // 3
    min_radius = min_side // 8

    for b in range(batch_size):
        num_blobs = np.random.randint(
            num_blobs_range[0], num_blobs_range[1] + 1)
        target_coverage = np.random.uniform(min_coverage, max_coverage)

        mask = torch.zeros(1, 1, height, width, device=device)

        for _ in range(num_blobs):
            # Low-frequency sine waves give a smooth distortion of the
            # distance field, so the outline wobbles organically instead of
            # looking like pixel noise.
            noise_field = torch.zeros(height, width, device=device)
            num_waves = np.random.randint(2, 5)
            for i in range(num_waves):
                freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
                freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
                phase_x = np.random.uniform(0, 2 * np.pi)
                phase_y = np.random.uniform(0, 2 * np.pi)
                # Later waves contribute progressively less
                amp = np.random.uniform(0.5, 1.0) * max_radius / (i + 1.5)
                noise_field += (
                    torch.sin(x_indices * freq_x + phase_x)
                    * torch.sin(y_indices * freq_y + phase_y)
                    * amp
                )

            # Basic ellipse parameters: centre in the middle half of the image
            center_y = np.random.randint(height // 4, 3 * height // 4)
            center_x = np.random.randint(width // 4, 3 * width // 4)
            radius = np.random.randint(min_radius, max_radius)
            scale_y = np.random.uniform(0.6, 1.4)
            scale_x = np.random.uniform(0.6, 1.4)
            theta = np.random.uniform(0, 2 * np.pi)
            cos_theta = float(np.cos(theta))
            sin_theta = float(np.sin(theta))

            # Rotate the centred coordinates FIRST, then stretch along the
            # rotated axes. Rotating after the stretch would be a no-op,
            # because a rotation does not change the Euclidean norm taken
            # below, and every ellipse would stay axis-aligned.
            dy = y_indices - center_y
            dx = x_indices - center_x
            rotated_y = dy * cos_theta - dx * sin_theta
            rotated_x = dy * sin_theta + dx * cos_theta
            distances = torch.sqrt(
                (rotated_y * scale_y) ** 2 + (rotated_x * scale_x) ** 2
            )

            # Warp the distance field and cut the base blob out of it
            blob = (distances + noise_field < radius).float()
            blob = blob.unsqueeze(0).unsqueeze(0)

            # Two blur + threshold passes: a wide one with a random cut-off
            # to reshape the outline, then a narrow one to clean the edge.
            rand_threshold = np.random.uniform(0.3, 0.6)
            blob = _blur_and_binarize(blob, large_kernel, rand_threshold)
            blob = _blur_and_binarize(blob, small_kernel, 0.5)

            # Union with the blobs generated so far
            mask = torch.maximum(mask, blob)

        # Nudge the coverage towards the target (single step, best-effort)
        current_coverage = mask.mean().item()
        if current_coverage > 0:  # an empty mask is left untouched
            if current_coverage < target_coverage * 0.7:  # Too small
                # Dilate: 5x5 max-pool
                mask = F.pad(mask, (2, 2, 2, 2), mode='reflect')
                mask = F.max_pool2d(mask, kernel_size=5, stride=1, padding=0)
            elif current_coverage > target_coverage * 1.3:  # Too large
                # Erode: keep pixels whose 3x3 neighbourhood is mostly set
                mask = F.pad(mask, (1, 1, 1, 1), mode='reflect')
                mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=0)
                mask = (mask > 0.7).float()

        # Final smooth and threshold
        mask = _blur_and_binarize(mask, small_kernel, 0.5)

        masks[b] = mask[0]

    return masks


def _blur_and_binarize(field, kernel, threshold):
    """Blur a (1, 1, H, W) tensor with a square kernel (reflect-padded) and threshold it to {0, 1}."""
    pad = kernel.shape[-1] // 2
    field = F.pad(field, (pad, pad, pad, pad), mode='reflect')
    return (F.conv2d(field, kernel, padding=0) > threshold).float()
def get_gaussian_kernel(kernel_size=5, sigma=1.0):
    """
    Returns a normalized 2D Gaussian kernel.

    The kernel is sampled at integer pixel offsets from its centre, so
    ``sigma`` is a standard deviation in pixels. (Sampling on a grid that
    spans a fixed multiple of ``sigma`` would make ``sigma`` cancel out and
    leave the kernel dependent on ``kernel_size`` only.)

    Args:
        kernel_size (int): Side length of the square kernel (>= 1)
        sigma (float): Standard deviation of the Gaussian in pixels (> 0)

    Returns:
        torch.Tensor: float32 tensor of shape (kernel_size, kernel_size)
        whose entries sum to 1.

    Raises:
        ValueError: If ``kernel_size`` < 1 or ``sigma`` <= 0.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # Pixel offsets from the kernel centre, e.g. [-2, -1, 0, 1, 2] for size 5
    offsets = torch.arange(kernel_size, dtype=torch.float32)
    offsets = offsets - (kernel_size - 1) / 2.0

    # The 2D Gaussian is separable: outer product of the 1D profile
    profile = torch.exp(-offsets ** 2 / (2.0 * sigma ** 2))
    gaussian = profile.view(-1, 1) * profile.view(1, -1)

    return gaussian / gaussian.sum()
def save_masks_as_images(masks, suffix="", output_dir="output"):
    """
    Write every mask of a batch to disk as an RGB JPEG (via PIL).

    Only channel 0 of each sample is used. Its values are multiplied by 255
    and cast to uint8, then replicated into three identical colour channels,
    so a 0/1 mask becomes a white shape on a black background. Files are
    named ``mask_<index><suffix>.jpg`` inside ``output_dir``, which is
    created if it does not exist.

    Args:
        masks (torch.Tensor): Batch of masks indexed as [batch, channel, H, W]
        suffix (str): Text inserted between the index and the extension
        output_dir (str): Destination directory
    """
    os.makedirs(output_dir, exist_ok=True)

    for i in range(masks.shape[0]):
        # First channel of sample i, moved to host memory
        plane = masks[i, 0].cpu().numpy()

        # 0..1 floats -> 0..255 bytes
        gray = (plane * 255).astype(np.uint8)

        # H x W -> H x W x 3 with the same value in every colour channel
        rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)

        target = os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg")
        Image.fromarray(rgb).save(target, quality=95)
def random_dialate_mask(mask, max_percent=0.05):
    """
    Randomly perturbs the boundary of each mask in a batch.

    For every sample an odd box size of at most ``max_size`` is drawn, where
    ``max_size = max(int(width * max_percent), 3)``. The sample is replaced
    by ``local_mean > threshold``, where ``local_mean`` is the zero-padded
    box average and ``threshold`` is drawn uniformly from [0.2, 0.8). A drawn
    size of 1 leaves the sample unchanged.

    Note that low thresholds grow the mask while high thresholds can shrink
    it, so the result is not a strict morphological dilation.

    Args:
        mask (torch.Tensor): Floating-point mask of shape
            [batch_size, channels, height, width]; any channel count works
        max_percent (float): Maximum box size as a fraction of the mask width

    Returns:
        torch.Tensor: Mask with the same shape and dtype as the input;
        processed samples contain only 0 and 1.
    """
    # Nothing to do for an empty batch (and torch.cat rejects an empty list)
    if mask.shape[0] == 0:
        return mask.clone()

    # Upper bound for the box size; never below 3 so a real box is possible
    max_size = max(int(mask.shape[-1] * max_percent), 3)

    out_chunks = []
    for chunk in mask.split(1, dim=0):
        kernel_size = np.random.randint(1, max_size)

        # A 1x1 box is the identity: keep the sample as it is
        if kernel_size < 2:
            out_chunks.append(chunk)
            continue

        # Odd sizes keep the output the same spatial size as the input
        if kernel_size % 2 == 0:
            kernel_size += 1

        # Zero-padded box average, computed independently per channel
        local_mean = F.avg_pool2d(
            chunk,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            count_include_pad=True,
        )

        # Random threshold for a varied effect
        threshold = np.random.uniform(0.2, 0.8)
        out_chunks.append((local_mean > threshold).to(chunk.dtype))

    return torch.cat(out_chunks, dim=0)
if __name__ == "__main__":
    # Demo / smoke test: generate a few batches, time them, and save the
    # last batch (plus its perturbed version) as JPEGs under "output".
    batch_size = 20
    height = 256
    width = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Generating {batch_size} random blob masks on {device}...")

    for _ in range(5):
        start = time.perf_counter()

        masks = generate_random_mask(
            batch_size=batch_size,
            height=height,
            width=width,
            device=device,
            min_coverage=0.2,
            max_coverage=0.8,
            num_blobs_range=(1, 3)
        )
        dialation = random_dialate_mask(masks)

        # CUDA kernels run asynchronously; wait for them so the measured
        # time covers the actual computation and not just the launches.
        if device == 'cuda':
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000

        print(f"Generated {batch_size} masks with shape: {masks.shape}")
        print(f"Time taken: {elapsed_ms:.2f} ms")

    print("Saving masks to 'output' directory...")
    save_masks_as_images(masks)
    save_masks_as_images(dialation, suffix="_dilated")
    print("Done!")