Spaces:
Running
Running
# Importing necessary libraries | |
import io | |
import os | |
import cv2 | |
import utils | |
import random | |
import zipfile | |
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
import streamlit as st | |
import albumentations as A | |
import matplotlib.pyplot as plt | |
# Function to fetch the names of augmentation techniques applicable to the selected option | |
# NOTE(review): this line calls st.cache_resource(...) and discards the result.
# It is not written as a decorator (there is no leading "@"), so the function
# defined below is not wrapped by it. Confirm whether caching was intended.
st.cache_resource(show_spinner=False)
def get_applicable_techniques(df, option):
    """Return the "Name" values of the rows whose `option` column is "Applicable".

    Args:
        df: DataFrame with a "Name" column and a column named by `option`.
        option: Name of the column to test against the string "Applicable".

    Returns:
        The "Name" column restricted to the matching rows (original index kept).
    """
    is_applicable = df[option] == "Applicable"
    return df.loc[is_applicable, "Name"]
# Function to generate unique colors for each given list of unique numbers | |
def generate_unique_colors(unique_numbers):
    """Assign each given number a distinct color sampled from the "hsv" colormap.

    Args:
        unique_numbers: Sized iterable of hashable identifiers.

    Returns:
        Dict mapping each identifier to a 3-tuple of ints in 0-255, ordered
        (blue, green, red).
    """
    cmap = plt.get_cmap("hsv")
    colors_bgr = {}
    for position, number in enumerate(unique_numbers):
        # The colormap yields RGBA floats in [0, 1]; keep only the RGB part.
        red, green, blue = cmap(position / len(unique_numbers))[:3]
        # Store the channels reversed (BGR) and scaled to the 0-255 range.
        colors_bgr[number] = (int(blue * 255), int(green * 255), int(red * 255))
    return colors_bgr
# Function to adjust zero values to a small positive number | |
def adjust_zero_value(value):
    """Return `value` unchanged unless it equals zero, in which case return 1e-9."""
    if value == 0:
        return 1e-9
    return value
# Function to parse a YOLO label file and convert it into Albumentations-compatible Bboxes format | |
def bboxes_label(label_file, class_dict):
    """Parse YOLO bounding-box label lines into an Albumentations-style dict.

    Each non-blank line must hold exactly five numbers:
    ``class_id x_center y_center width height``. Class ids are shifted by +1
    (0 is reserved for the background) and stored as ``int`` so that they are
    written back as integers when the labels are exported.

    Args:
        label_file: Iterable of ``bytes`` lines (e.g. from ``readlines()``).
        class_dict: Mapping whose keys are the valid (already shifted) class ids.

    Returns:
        ``{"bboxes": [[x, y, w, h], ...], "class_labels": [int, ...]}``, or
        ``None`` if any line is malformed, a value is outside [0, 1], a class
        id is unknown, or no boxes were found.
    """
    # Zero is valid in YOLO files but is replaced by a tiny positive value here
    # (the same substitution adjust_zero_value performs).
    epsilon = 1e-9
    bboxes_data = {"bboxes": [], "class_labels": []}

    for line in label_file:
        try:
            fields = line.decode().split()
            if not fields:
                # Blank lines carry no annotation; skip instead of rejecting the file.
                continue

            raw_class_id, *coords = map(float, fields)
            if len(coords) != 4 or not raw_class_id.is_integer():
                return None

            class_id = int(raw_class_id) + 1
            bbox = [coord if coord != 0 else epsilon for coord in coords]

            if class_id not in class_dict or not all(0 <= v <= 1 for v in bbox):
                return None
        except (ValueError, TypeError, AttributeError):
            # Undecodable bytes, non-numeric fields, or unusable arguments.
            return None

        bboxes_data["bboxes"].append(bbox)
        bboxes_data["class_labels"].append(class_id)

    # An empty file (or one with only blank lines) is not a valid label file.
    return bboxes_data if bboxes_data["bboxes"] else None
# Function to parse a YOLO label file and convert it into compatible Mask format | |
def masks_label(label_file, class_dict):
    """Parse YOLO segmentation label lines into a dict of polygons.

    Each non-blank line must be ``class_id x1 y1 x2 y2 ...`` with an even
    number of coordinates describing at least three points, all within [0, 1].
    Class ids are shifted by +1 (0 is reserved for the background).

    Args:
        label_file: Iterable of ``bytes`` lines (e.g. from ``readlines()``).
        class_dict: Mapping whose keys are the valid (already shifted) class ids.

    Returns:
        ``{"masks": [[(x, y), ...], ...], "class_labels": [int, ...]}``, or
        ``None`` if any line is malformed, out of range, refers to an unknown
        class, or if no polygons were found.
    """
    mask_data = {"masks": [], "class_labels": []}

    for line in label_file:
        try:
            parts = line.decode().split()
            if not parts:
                # Blank lines carry no annotation; skip instead of rejecting the file.
                continue

            class_id = int(parts[0]) + 1
            points = [float(p) for p in parts[1:]]

            # A polygon needs complete (x, y) pairs and at least three vertices;
            # a class-only line would otherwise yield an empty polygon.
            if (
                class_id not in class_dict
                or len(points) < 6
                or len(points) % 2 != 0
                or not all(0 <= p <= 1 for p in points)
            ):
                return None
        except (ValueError, TypeError, AttributeError):
            # Undecodable bytes, non-numeric fields, or unusable arguments.
            return None

        mask_data["class_labels"].append(class_id)
        mask_data["masks"].append(list(zip(points[0::2], points[1::2])))

    # An empty file (or one with only blank lines) is not a valid label file.
    return mask_data if mask_data["masks"] else None
# Function to generate mask for albumentations format | |
def generate_mask(masks, class_ids, image_height, image_width):
    """Rasterize normalized polygons into a single-channel class-id mask.

    Args:
        masks: Sequence of polygons, each a sequence of (x, y) pairs that are
            scaled by the image width/height below.
        class_ids: Class id painted for the polygon at the same position.
        image_height: Number of mask rows.
        image_width: Number of mask columns.

    Returns:
        int32 array of shape (image_height, image_width); pixels not covered by
        any polygon are 0. Polygons are drawn in order, so later ones overwrite
        earlier ones where they overlap.
    """
    canvas = np.zeros((image_height, image_width), dtype=np.int32)
    for polygon, class_id in zip(masks, class_ids):
        # Convert normalized coordinates to integer pixel positions (truncating).
        vertices = np.array(
            [(int(x * image_width), int(y * image_height)) for x, y in polygon],
            dtype=np.int32,
        )
        cv2.fillPoly(canvas, [vertices], color=class_id)
    return canvas
# Function to convert a single-channel mask back to YOLO format | |
def mask_to_yolo(mask):
    """Convert a single-channel class-id mask back into normalized polygons.

    Every non-zero value in `mask` is treated as one class; each external
    contour found for that value becomes one polygon.

    Args:
        mask: 2-D array of class ids where 0 marks the background.

    Returns:
        ``{"masks": [[(x, y), ...], ...], "class_labels": [...]}`` with x divided
        by the mask width and y by the mask height.
    """
    polygons = []
    labels = []
    for value in np.unique(mask):
        if value != 0:  # 0 is the background, nothing to extract
            binary = np.uint8(mask == value)
            contours, _ = cv2.findContours(
                binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                polygons.append(
                    [
                        (point[0][0] / mask.shape[1], point[0][1] / mask.shape[0])
                        for point in contour
                    ]
                )
                labels.append(value)
    return {"masks": polygons, "class_labels": labels}
# Function to create a user interface for adjusting bounding box parameters | |
def bbox_params():
    """Render the bounding-box filter widgets and return their current values.

    The widgets live inside a "Bounding Box Parameters" expander. Changing any
    of them calls ``utils.reset_validation_trigger``. Widget keys are read from
    ``st.session_state["bbox1_key"]`` ... ``["bbox5_key"]``.

    Returns:
        Dict with keys "min_area", "min_visibility", "min_width", "min_height"
        and "check_each_transform".
    """
    with st.expander("Bounding Box Parameters"):
        # Create two columns for the input widgets
        col1, col2 = st.columns(2)
        with col1:
            min_area = st.number_input(
                "Minimum Area",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum area of a bounding box. Boxes smaller than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox1_key"],
            )
            # Visibility is a fraction of the original box area, so the widget
            # must be a float limited to [0, 1]; an integer range of 0..1000
            # would make every value above 1 discard all boxes.
            min_visibility = st.number_input(
                "Minimum Visibility",
                min_value=0.0,
                max_value=1.0,
                value=0.0,
                step=0.01,
                help="Minimum fraction of area for a bounding box to remain in the list.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox2_key"],
            )
        with col2:
            min_width = st.number_input(
                "Minimum Width",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum width of a bounding box. Boxes narrower than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox3_key"],
            )
            min_height = st.number_input(
                "Minimum Height",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum height of a bounding box. Boxes shorter than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox4_key"],
            )
        check_each_transform = st.checkbox(
            "Check Each Transform",
            help="If checked, bounding boxes will be checked after each dual transform.",
            value=True,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["bbox5_key"],
        )

    # Return the collected parameters as a dictionary
    return {
        "min_area": min_area,
        "min_visibility": min_visibility,
        "min_width": min_width,
        "min_height": min_height,
        "check_each_transform": check_each_transform,
    }
# Function to check if the uploaded images and labels are valid | |
# NOTE(review): this line calls st.cache_resource(...) and discards the result.
# It is not written as a decorator (there is no leading "@"), so the function
# defined below is not wrapped by it. Confirm whether caching was intended.
st.cache_resource(show_spinner=False)
def check_valid_labels(uploaded_files, selected_option, class_dict):
    """Validate uploaded images and label files and pair them by file name.

    Files are sorted by MIME type: "image/jpeg" / "image/png" are images,
    "text/plain" files are parsed as labels with ``bboxes_label`` or
    ``masks_label`` depending on `selected_option`. Other types are ignored.
    Progress and the outcome are reported through Streamlit widgets.

    Args:
        uploaded_files: Sequence of uploaded file objects exposing ``name``,
            ``type``, ``seek`` and ``readlines``.
        selected_option: "Bboxes" or "Masks"; any other value makes every label
            file count as invalid.
        class_dict: Mapping of valid class ids, forwarded to the label parser.

    Returns:
        ``(is_valid, image_files, label_files, sample_image, sample_label)``.
        On failure this is ``(False, {}, {}, None, None)``. When only images
        were uploaded, `label_files` is ``{}`` and `sample_label` is ``None``.
    """
    invalid = (False, {}, {}, None, None)

    # Early exit if no files are uploaded
    if not uploaded_files:
        st.warning("Please upload at least one image to apply augmentation.", icon="⚠️")
        return invalid

    label_parsers = {"Bboxes": bboxes_label, "Masks": masks_label}
    image_files, label_files = {}, {}
    image_count, label_count = 0, 0
    first_file_name = os.path.splitext(uploaded_files[0].name)[0]

    progress_bar = st.progress(0)
    progress_text = st.empty()
    total_files = len(uploaded_files)

    for index, file in enumerate(uploaded_files):
        file.seek(0)  # Reset file pointer to ensure proper file reading
        stem = os.path.splitext(file.name)[0]

        if file.type in ("image/jpeg", "image/png"):
            image_files[stem] = file
            image_count += 1
        elif file.type == "text/plain":
            file_content = file.readlines()
            parser = label_parsers.get(selected_option)
            # An unknown option has no parser; treat the label as invalid
            # rather than failing on an unbound variable.
            label_data = parser(file_content, class_dict) if parser else None
            if label_data is None:
                # Clear the progress widgets before bailing out so they are
                # not left on screen next to the warning.
                progress_bar.empty()
                progress_text.empty()
                st.warning(
                    f"Invalid label format or data in file: {file.name}",
                    icon="⚠️",
                )
                return invalid
            label_files[stem] = label_data
            label_count += 1

        progress_bar.progress((index + 1) / total_files)
        progress_text.text(f"Validating file {index + 1} of {total_files}")

    # Remove progress bar and progress text after processing
    progress_bar.empty()
    progress_text.empty()

    unique_image_names = set(image_files)
    unique_label_names = set(label_files)

    # Two files sharing a stem overwrite each other in the dicts, so the dict
    # sizes fall below the counters.
    if len(unique_image_names) != image_count or len(unique_label_names) != label_count:
        st.warning(
            "Duplicate file names detected. Please ensure each image and label has a unique name.",
            icon="⚠️",
        )
        return invalid

    if not image_files:
        st.warning("Please upload an image to apply augmentation.", icon="⚠️")
        return invalid

    # The first uploaded file may be of an ignored type; fall back to the first
    # accepted image so the sample lookup below cannot raise KeyError.
    if first_file_name not in image_files:
        first_file_name = next(iter(image_files))

    if not label_files:
        st.info(
            f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
            icon="✅",
        )
        return True, image_files, {}, image_files[first_file_name], None

    if len(image_files) != len(label_files):
        st.warning(
            "Count Mismatch: The number of uploaded images and labels does not match.",
            icon="⚠️",
        )
        return invalid

    if unique_image_names != unique_label_names:
        st.warning(
            "Mismatch detected: Some images do not have corresponding label files.",
            icon="⚠️",
        )
        return invalid

    st.info(
        f"Validated: {len(image_files)} images and labels successfully matched.",
        icon="✅",
    )
    return (
        True,
        image_files,
        label_files,
        image_files[first_file_name],
        label_files[first_file_name],
    )
# Function to apply an augmentation technique to an image and return any errors along with the processed image | |
def apply_and_test_augmentation(
    augmentation,
    params,
    image,
    label,
    label_type,
    label_input_parameters,
    allowed_image_types,
):
    """Apply a single augmentation to an image and report whether it failed.

    Args:
        augmentation: Name of the augmentation, passed to
            ``utils.apply_albumentation`` and used in warning messages.
        params: Parameters passed to ``utils.apply_albumentation``.
        image: Image array; its dtype and channel count are checked with
            ``utils.is_image_type_allowed``.
        label: Parsed label dict, or ``None`` to augment the image alone.
        label_type: "Bboxes" or "Masks"; only consulted when `label` is given.
        label_input_parameters: Dict with "min_area", "min_visibility",
            "min_width", "min_height" and "check_each_transform"; only used
            for "Bboxes".
        allowed_image_types: Accepted image types, forwarded to the type check.

    Returns:
        ``(error_occurred, processed_image)``. On any failure a Streamlit
        warning is shown and ``(True, None)`` is returned.
    """
    try:
        # Check the data type and number of channels of the input image
        input_image_type = image.dtype
        num_channels = image.shape[2] if len(image.shape) == 3 else 1

        if not utils.is_image_type_allowed(
            input_image_type, num_channels, allowed_image_types
        ):
            allowed_types_formatted = ", ".join(map(str, allowed_image_types))
            st.warning(
                f"Error applying {augmentation}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
                icon="⚠️",
            )
            return True, None  # Error occurred

        # Fixed seed for Python's `random` module so repeated runs match.
        random.seed(0)

        if label is None:
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            return False, transform(image=image)["image"]

        if label_type == "Bboxes":
            step = utils.apply_albumentation(params, augmentation)
            box_config = A.BboxParams(
                format="yolo",
                label_fields=["class_labels"],
                min_area=label_input_parameters["min_area"],
                min_visibility=label_input_parameters["min_visibility"],
                min_width=label_input_parameters["min_width"],
                min_height=label_input_parameters["min_height"],
                check_each_transform=label_input_parameters["check_each_transform"],
            )
            transform = A.Compose([step], bbox_params=box_config)
            result = transform(
                image=image,
                bboxes=label["bboxes"],
                class_labels=label["class_labels"],
            )
            return False, result["image"]

        if label_type == "Masks":
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            mask = generate_mask(
                label["masks"],
                label["class_labels"],
                image.shape[0],
                image.shape[1],
            )
            return False, transform(image=image, mask=mask)["image"]

        # Report an unknown label type explicitly instead of falling through
        # to an unbound-variable error with a confusing message.
        raise ValueError(f"Unsupported label type: {label_type!r}")
    except Exception as e:
        # Boundary handler: any failure is shown to the user as a warning.
        st.warning(f"Error applying {augmentation}: {e}", icon="⚠️")
        return True, None  # Error occurred
# Generates a DataFrame detailing augmentation technique parameters and descriptions | |
def create_augmentations_dataframe(augmentations_params, augmentation_params_db):
    """Tabulate every chosen augmentation parameter with its value and description.

    Args:
        augmentations_params: Mapping ``{augmentation name: {parameter: value}}``.
        augmentation_params_db: DataFrame with the columns "Name",
            "Parameter Name" and "Parameter Description".

    Returns:
        DataFrame with the columns "augmentation", "Parameter", "Value" and
        "Description", one row per parameter. Parameters missing from the
        database get the description "Description not available".
    """

    def describe(aug_name, param_name):
        # Narrow the database to this augmentation, then to this parameter.
        for_augmentation = augmentation_params_db[
            augmentation_params_db["Name"] == aug_name
        ]
        matches = for_augmentation[for_augmentation["Parameter Name"] == param_name]
        if matches.empty:
            return "Description not available"
        return matches["Parameter Description"].iloc[0]

    rows = [
        [aug_name, param_name, param_value, describe(aug_name, param_name)]
        for aug_name, params in augmentations_params.items()
        for param_name, param_value in params.items()
    ]
    return pd.DataFrame(
        rows, columns=["augmentation", "Parameter", "Value", "Description"]
    )
# Function to Generate Python Code for Augmentation with Bounding Box Labels | |
def generate_python_code_bboxes(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    """Generate a standalone Python script that augments images with YOLO bboxes.

    The emitted script reads every .png/.jpg/.jpeg in ``input_directory`` with
    its same-named ``.txt`` label, runs the configured Albumentations pipeline
    `num_variations` times per image and writes the results to
    ``output_directory``.

    Args:
        augmentations_params: Mapping ``{technique name: {parameter: value}}``;
            each entry becomes ``A.<technique>(<parameter>=<value>, ...)``.
        label_input_parameters: Dict with "min_area", "min_visibility",
            "min_width", "min_height" and "check_each_transform", written into
            ``A.BboxParams``.
        num_variations: Number of augmented variations per image.
        include_original: When true, the script also copies each original image
            and label file into the output directory.

    Returns:
        The script source as a string.
    """
    # The script is assembled line by line with explicit indentation so the
    # emitted code is valid Python.
    script = [
        "# Importing necessary libraries",
        "import os",
        "import cv2",
        "import shutil",
        "import albumentations as A",
        "",
        "# Define the paths for input and output directories",
        "input_directory = 'path/to/input'",
        "output_directory = 'path/to/output'",
        "",
        "# Function to read YOLO format labels",
        "def read_yolo_label(label_path):",
        "    bboxes = []",
        "    class_ids = []",
        "    with open(label_path, 'r') as file:",
        "        for line in file:",
        "            class_id, x_center, y_center, width, height = map(float, line.split())",
        "            bboxes.append([x_center, y_center, width, height])",
        "            class_ids.append(int(class_id))",
        "    return bboxes, class_ids",
        "",
        "# Function to create an augmentation pipeline using Albumentations",
        "def process_image(image, bboxes, class_ids):",
        "    # Define the sequence of augmentation techniques",
        "    pipeline = A.Compose([",
    ]

    for technique, params in augmentations_params.items():
        arguments = ", ".join(f"{k}={v}" for k, v in params.items())
        script.append(f"        A.{technique}({arguments}),")

    script += [
        "    ], bbox_params=A.BboxParams(",
        "        format='yolo',",
        "        label_fields=['class_labels'],",
        f"        min_area={label_input_parameters['min_area']},",
        f"        min_visibility={label_input_parameters['min_visibility']},",
        f"        min_width={label_input_parameters['min_width']},",
        f"        min_height={label_input_parameters['min_height']},",
        f"        check_each_transform={label_input_parameters['check_each_transform']}",
        "    ))",
        "    # Apply the augmentation pipeline",
        "    return pipeline(image=image, bboxes=bboxes, class_labels=class_ids)",
        "",
        "# Function to process a batch of images",
        "def process_batch(input_directory, output_directory):",
        "    for filename in os.listdir(input_directory):",
        "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):",
        "            image_path = os.path.join(input_directory, filename)",
        "            label_path = os.path.splitext(image_path)[0] + '.txt'",
        "",
        "            # Read the image and label",
        "            image = cv2.imread(image_path)",
        "            bboxes, class_ids = read_yolo_label(label_path)",
        "",
        "            # Include original image and label",
    ]

    if include_original:
        script += [
            "            shutil.copy2(image_path, output_directory)",
            "            shutil.copy2(label_path, output_directory)",
            "",
        ]

    script += [
        "            # Generate variations for each image and process them",
        f"            for variation in range({num_variations}):",
        "                processed_data = process_image(image, bboxes, class_ids)",
        "                processed_image = processed_data['image']",
        "                processed_bboxes = processed_data['bboxes']",
        "                processed_class_ids = processed_data['class_labels']",
        "",
        "                # Save the processed image",
        "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'",
        "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)",
        "",
        "                with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:",
        "                    for bbox, class_id in zip(processed_bboxes, processed_class_ids):",
        "                        label_line = ' '.join(map(str, [class_id] + list(bbox)))",
        "                        label_file.write(label_line + '\\n')",
        "",
        "# Execute the batch processing function with the specified parameters",
        "process_batch(input_directory, output_directory)",
    ]

    return "\n".join(script) + "\n"
def generate_python_code_masks(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    """Generate a standalone Python script that augments images with YOLO masks.

    The emitted script reads every .png/.jpg/.jpeg in ``input_directory``,
    rasterizes the same-named ``.txt`` polygon label (if present) into a mask,
    runs the configured Albumentations pipeline `num_variations` times per
    image and writes the images and re-extracted polygons to
    ``output_directory``.

    Args:
        augmentations_params: Mapping ``{technique name: {parameter: value}}``;
            each entry becomes ``A.<technique>(<parameter>=<value>, ...)``.
        label_input_parameters: Accepted for signature parity with
            ``generate_python_code_bboxes``; not used for masks.
        num_variations: Number of augmented variations per image.
        include_original: When true, the script also copies each original image
            (and its label file, if one exists) into the output directory.

    Returns:
        The script source as a string.
    """
    # The script is assembled line by line with explicit indentation so the
    # emitted code is valid Python.
    script = [
        "# Importing necessary libraries",
        "import os",
        "import cv2",
        "import shutil",
        "import numpy as np",
        "import albumentations as A",
        "",
        "# Define the paths for input and output directories",
        "input_directory = 'path/to/input'",
        "output_directory = 'path/to/output'",
        "",
        "# Function to read YOLO mask format and convert to mask",
        "def read_yolo_label(label_path, image_shape):",
        "    mask = np.full(image_shape, -1, dtype=np.int32)",
        "    with open(label_path, 'r') as file:",
        "        for line in file:",
        "            parts = line.strip().split()",
        "            class_id = int(parts[0])",
        "            points = np.array([float(p) for p in parts[1:]]).reshape(-1, 2)",
        "            points = (points * [image_shape[1], image_shape[0]]).astype(np.int32)",
        "            cv2.fillPoly(mask, [points], class_id)",
        "    return mask",
        "",
        "# Function to convert mask to YOLO format",
        "def mask_to_yolo(mask):",
        "    yolo_format = ''",
        "    for class_id in np.unique(mask):",
        "        if class_id == -1:",
        "            continue",
        "        contours, _ = cv2.findContours(",
        "            np.uint8(mask == class_id), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE",
        "        )",
        "        for contour in contours:",
        "            contour = contour.flatten().tolist()",
        # Flattened contours alternate x, y: even positions (x) must be divided
        # by the width (shape[1]) and odd positions (y) by the height (shape[0]).
        "            normalized_contour = [",
        "                str(coord / mask.shape[1 - i % 2]) for i, coord in enumerate(contour)",
        "            ]",
        "            yolo_format += f'{class_id} ' + ' '.join(normalized_contour) + '\\n'",
        "    return yolo_format",
        "",
        "# Function to create an augmentation pipeline using Albumentations",
        "def process_image(image, mask):",
        "    # Define the sequence of augmentation techniques",
        "    pipeline = A.Compose([",
    ]

    for technique, params in augmentations_params.items():
        arguments = ", ".join(f"{k}={v}" for k, v in params.items())
        script.append(f"        A.{technique}({arguments}),")

    script += [
        "    ])",
        "    # Apply the augmentation pipeline",
        "    return pipeline(image=image, mask=mask)",
        "",
        "# Function to process a batch of images",
        "def process_batch(input_directory, output_directory, include_original=False, num_variations=1):",
        "    for filename in os.listdir(input_directory):",
        "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):",
        "            image_path = os.path.join(input_directory, filename)",
        "            label_path = os.path.splitext(image_path)[0] + '.txt'",
        "",
        "            # Read the image",
        "            image = cv2.imread(image_path)",
        "",
        "            # Include original image",
    ]

    if include_original:
        # The script treats labels as optional, so only copy one that exists.
        script += [
            "            shutil.copy2(image_path, output_directory)",
            "            if os.path.exists(label_path):",
            "                shutil.copy2(label_path, output_directory)",
            "",
        ]

    script += [
        "            # Check if label file exists and read mask",
        "            mask = None",
        "            if os.path.exists(label_path):",
        "                mask = read_yolo_label(label_path, image.shape[:2])",
        "",
        "            # Generate variations for each image and process them",
        f"            for variation in range({num_variations}):",
        "                processed_image, processed_mask = image, mask",
        "                if mask is not None:",
        "                    processed_data = process_image(image, mask)",
        "                    processed_image, processed_mask = processed_data['image'], processed_data['mask']",
        "",
        "                # Save the processed image",
        "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}.jpg'",
        "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)",
        "",
        "                # Save the processed label in YOLO format",
        "                if processed_mask is not None:",
        "                    processed_label_str = mask_to_yolo(processed_mask)",
        "                    with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:",
        "                        label_file.write(processed_label_str)",
        "",
        "# Execute the batch processing function with the specified parameters",
        "process_batch(input_directory, output_directory)",
    ]

    return "\n".join(script) + "\n"
# Function to create an augmentation pipeline based on the selected techniques and their parameters | |
def create_augmentation_pipeline(
    selected_augmentations, augmentation_params, label_type, label_input_parameters=None
):
    """Compose the selected augmentations into one Albumentations transform.

    Args:
        selected_augmentations: Iterable of augmentation names, in pipeline order.
        augmentation_params: Mapping from augmentation name to its parameters;
            each pair is passed to ``utils.apply_albumentation``.
        label_type: ``None`` (image only), "Bboxes" or "Masks".
        label_input_parameters: Dict with "min_area", "min_visibility",
            "min_width", "min_height" and "check_each_transform"; required for
            "Bboxes", ignored otherwise.

    Returns:
        The composed transform, or ``None`` (after showing a Streamlit warning)
        if composing failed or `label_type` is not supported.
    """
    pipeline = [
        utils.apply_albumentation(augmentation_params[aug_name], aug_name)
        for aug_name in selected_augmentations
    ]

    try:
        # Fixed seed for Python's `random` module so repeated runs match.
        random.seed(0)

        # Masks need no extra Compose configuration; they are passed at call time.
        if label_type is None or label_type == "Masks":
            return A.Compose(pipeline)

        if label_type == "Bboxes":
            return A.Compose(
                pipeline,
                bbox_params=A.BboxParams(
                    format="yolo",
                    label_fields=["class_labels"],
                    min_area=label_input_parameters["min_area"],
                    min_visibility=label_input_parameters["min_visibility"],
                    min_width=label_input_parameters["min_width"],
                    min_height=label_input_parameters["min_height"],
                    check_each_transform=label_input_parameters["check_each_transform"],
                ),
            )

        # Report an unknown label type explicitly instead of relying on an
        # unbound-variable error.
        raise ValueError(f"Unsupported label type: {label_type!r}")
    except Exception as e:
        # Boundary handler: include the cause so the user can act on it.
        st.warning(f"Error applying augmentation: {e}")
        return None  # Error occurred
# Function to convert label data from dictionary format to YOLO format | |
def convert_labels_to_yolo_format(label_data, class_dict):
    """Serialize parsed label data back into YOLO text lines.

    Class labels are shifted back by -1 (undoing the +1 applied when the
    labels were parsed) and written as integers even when they arrive as
    floats or NumPy scalars.

    Args:
        label_data: Dict with "class_labels" plus "bboxes" (4-value sequences:
            x_center, y_center, width, height) and/or "masks" (sequences of
            (x, y) points).
        class_dict: Unused; kept for interface compatibility.

    Returns:
        One newline-terminated line per bbox, then one per mask; an empty
        string when neither key is present.
    """
    lines = []

    # Convert bounding boxes to YOLO format
    if "bboxes" in label_data:
        for bbox, class_label in zip(label_data["bboxes"], label_data["class_labels"]):
            x_center, y_center, width, height = bbox
            lines.append(
                f"{int(class_label) - 1} {x_center} {y_center} {width} {height}"
            )

    # Convert masks to YOLO format
    if "masks" in label_data:
        for polygon, class_label in zip(label_data["masks"], label_data["class_labels"]):
            # Flatten the (x, y) points into one space-separated coordinate list.
            coords = " ".join(str(coord) for point in polygon for coord in point)
            lines.append(f"{int(class_label) - 1} {coords}")

    return "".join(f"{line}\n" for line in lines)
# Function to apply the augmentation pipeline to the image based on the label type | |
def apply_augmentation_pipeline(image, label_file, label_type, transform):
    """Run `transform` on an image and its labels.

    Args:
        image: Image array passed to `transform` as ``image=``.
        label_file: Parsed label dict ("bboxes" or "masks" plus
            "class_labels"); ignored when `label_type` is ``None``.
        label_type: ``None`` (image only), "Bboxes" or "Masks".
        transform: Callable accepting keyword arguments and returning a dict
            with an "image" entry (plus "bboxes"/"class_labels" or "mask").

    Returns:
        ``(processed_image, processed_label)``. `processed_label` is ``None``
        without labels, a dict with "bboxes" and "class_labels" for boxes, or
        the output of ``mask_to_yolo`` for masks.

    Raises:
        ValueError: If `label_type` is not one of the supported values.
    """
    if label_type is None:
        return transform(image=image)["image"], None

    if label_type == "Bboxes":
        output = transform(
            image=image,
            bboxes=label_file["bboxes"],
            class_labels=label_file["class_labels"],
        )
        processed_label = {
            "bboxes": output["bboxes"],
            "class_labels": output["class_labels"],
        }
        return output["image"], processed_label

    if label_type == "Masks":
        # Rasterize the polygons, transform image and mask together, then
        # extract polygons again from the transformed mask.
        mask = generate_mask(
            label_file["masks"],
            label_file["class_labels"],
            image.shape[0],
            image.shape[1],
        )
        output = transform(image=image, mask=mask)
        return output["image"], mask_to_yolo(output["mask"])

    # Fail with a clear message instead of an unbound-variable error.
    raise ValueError(f"Unsupported label type: {label_type!r}")
# Function to process images and labels, apply augmentations, and create a zip file with the results | |
def process_images_and_labels(
    image_files,
    label_files,
    selected_augmentations,
    _augmentations_params,
    label_type,
    label_input_parameters,
    num_variations,
    include_original,
    class_dict,
):
    """Augment every uploaded image and pack the results into an in-memory zip.

    Results are published through ``st.session_state``:
    "image_repository_augmentation" (resized images and labels keyed by file
    name), "processed_image_mapping_augmentation" (original file name -> list
    of processed file names), "unique_images_names" and
    "zip_data_augmentation" (the zip bytes).

    Args:
        image_files: Mapping from file stem to uploaded image file object.
        label_files: Mapping from file stem to parsed label dict; empty when no
            labels were uploaded, in which case `label_type` is ignored.
        selected_augmentations: Augmentation names, in pipeline order.
        _augmentations_params: Mapping from augmentation name to its parameters.
        label_type: "Bboxes" or "Masks".
        label_input_parameters: Bounding-box filter parameters for the pipeline.
        num_variations: Number of augmented variations per image.
        include_original: Whether originals are added to the zip as well.
        class_dict: Forwarded to ``convert_labels_to_yolo_format``.

    Returns:
        None.
    """
    zip_buffer = io.BytesIO()  # In-memory buffer for the zip file
    st.session_state["image_repository_augmentation"] = {}
    st.session_state["processed_image_mapping_augmentation"] = {}
    st.session_state["unique_images_names"] = []

    # Create progress bar and text elements in Streamlit
    progress_bar = st.progress(0)
    progress_text = st.empty()

    with zipfile.ZipFile(
        zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
    ) as zip_file:
        # Without label files the pipeline runs in image-only mode.
        effective_label_type = None if len(label_files) == 0 else label_type

        transform = create_augmentation_pipeline(
            selected_augmentations,
            _augmentations_params,
            effective_label_type,
            label_input_parameters,
        )

        total_images = len(image_files) * num_variations
        processed_count = 0  # Counter for processed images

        for image_name, image_file in image_files.items():
            image_file.seek(0)  # Reset file pointer to start
            raw_bytes = image_file.read()
            file_bytes = np.asarray(bytearray(raw_bytes), dtype=np.uint8)
            original_image = cv2.cvtColor(
                cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
            )
            original_image_resized = utils.resize_image(original_image)

            original_file_name = image_file.name
            label_file = label_files.get(image_name)

            if include_original:
                # Store the uploaded bytes unchanged so the content always
                # matches the file extension (e.g. a .png stays a PNG).
                zip_file.writestr(original_file_name, raw_bytes)
                if label_file is not None:
                    zip_file.writestr(
                        f"{image_name}.txt",
                        convert_labels_to_yolo_format(label_file, class_dict),
                    )

            st.session_state["unique_images_names"].append(original_file_name)
            st.session_state["processed_image_mapping_augmentation"][
                original_file_name
            ] = []
            st.session_state["image_repository_augmentation"][original_file_name] = {
                "image": original_image_resized,
                "label": label_file,
            }

            # Apply augmentations and generate variations
            for i in range(num_variations):
                random.seed(i)  # a different, reproducible seed per variation
                processed_image, processed_label = apply_augmentation_pipeline(
                    original_image,
                    label_file,
                    effective_label_type,
                    transform,
                )

                img_buffer = io.BytesIO()
                Image.fromarray(processed_image).save(img_buffer, format="JPEG")

                # `image_name` is already the extension-less stem; use it whole
                # so dotted names such as "scan.v2" do not collapse to "scan"
                # and collide with each other.
                processed_stem = f"processed_{image_name}_{i}"
                processed_filename = f"{processed_stem}.jpg"
                zip_file.writestr(processed_filename, img_buffer.getvalue())

                st.session_state["processed_image_mapping_augmentation"][
                    original_file_name
                ].append(processed_filename)
                st.session_state["image_repository_augmentation"][
                    processed_filename
                ] = {
                    "image": utils.resize_image(processed_image),
                    "label": processed_label,
                }

                # Save processed labels in YOLO format if the image had labels
                if label_file is not None:
                    zip_file.writestr(
                        f"{processed_stem}.txt",
                        convert_labels_to_yolo_format(processed_label, class_dict),
                    )

                processed_count += 1
                progress_bar.progress(processed_count / total_images)
                progress_text.text(
                    f"Processing image {processed_count} of {total_images}"
                )

    # Remove the progress bar and text after processing is complete
    progress_bar.empty()
    progress_text.empty()

    # Read the buffer only after the ZipFile is closed, so the archive is complete.
    zip_buffer.seek(0)
    st.session_state["zip_data_augmentation"] = zip_buffer.getvalue()
# Function to overlay labels on images | |
def overlay_labels(image, labels_to_plot, label_file, label_type, colors, class_dict):
    """Draw the selected class annotations on an image and return the result.

    Args:
        image: NumPy image array. A 2-D or single-channel array is expanded
            to three channels before drawing.
        labels_to_plot: Container of class labels to draw; annotations whose
            label is not in it are skipped.
        label_file: Parsed annotations, or ``None`` when the image has none.
            For ``"Bboxes"`` it is indexed with ``"bboxes"`` (entries of
            ``x_center, y_center, width, height`` as fractions of the image
            size) and ``"class_labels"``; for ``"Masks"`` with ``"masks"``
            (sequences of fractional ``(x, y)`` points) and ``"class_labels"``.
        label_type: ``"Bboxes"`` or ``"Masks"``; any other value draws nothing.
        colors: Mapping from class label to the colour tuple used for drawing.
        class_dict: Mapping from class label to its display name; labels that
            are missing are rendered as ``"Unknown"``.

    Returns:
        A new array carrying the overlays. The array passed in by the caller
        is never modified.
    """
    # OpenCV draws in place. Work on a private array so the caller's image
    # does not accumulate overlays; grayscale conversion already allocates a
    # new array, colour input needs an explicit copy.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:
        image = image.copy()

    # Images without a label file have nothing to overlay.
    if label_file is None:
        return image

    img_height, img_width = image.shape[:2]

    # Overlay Bounding Boxes
    if label_type == "Bboxes":
        for bbox, label in zip(label_file["bboxes"], label_file["class_labels"]):
            if label not in labels_to_plot:
                continue
            # Convert bbox from yolo format to xmin, ymin, xmax, ymax (pixels)
            x_center, y_center, width, height = bbox
            xmin = int((x_center - width / 2) * img_width)
            xmax = int((x_center + width / 2) * img_width)
            ymin = int((y_center - height / 2) * img_height)
            ymax = int((y_center + height / 2) * img_height)
            # Draw the rectangle in the colour assigned to the class
            image = cv2.rectangle(
                image, (xmin, ymin), (xmax, ymax), colors[label], thickness=2
            )
            # Put the class name just above the top-left corner of the box
            cv2.putText(
                image,
                class_dict.get(label, "Unknown"),
                (xmin, ymin - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),  # Black color for text
                2,
            )
    # Overlay Mask
    elif label_type == "Masks":
        for polygon, label in zip(label_file["masks"], label_file["class_labels"]):
            if label not in labels_to_plot:
                continue
            # Convert polygon points from yolo format to image coordinates
            polygon_points = [
                (int(x * img_width), int(y * img_height)) for x, y in polygon
            ]
            # A polygon without points cannot be filled or anchored with text
            if not polygon_points:
                continue
            # Fill the polygon on a scratch copy, then blend it 50/50 into
            # the working image so the underlying pixels stay visible
            filled = image.copy()
            cv2.fillPoly(
                filled, [np.array(polygon_points, dtype=np.int32)], colors[label]
            )
            cv2.addWeighted(filled, 0.5, image, 0.5, 0, image)
            # Put the class name near the first point of the polygon
            cv2.putText(
                image,
                class_dict.get(label, "Unknown"),
                polygon_points[0],
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),  # Black color for text
                2,
            )
    return image
# Function to display the generated code with a usage description and a download button
def display_code_and_download_button(generated_code):
    """Show the generated script in an expander next to a download button.

    The expander is split into two columns: the wider one renders a usage
    description followed by ``generated_code`` as Python source, the narrower
    one offers the same text as a downloadable ``augmentation_script.py``.

    Args:
        generated_code: Source text of the generated augmentation script.
    """
    heading_markdown = """
### Description of the Code Pipeline
"""
    usage_markdown = """
<div style='text-align: justify;'>
This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.
**To use this script:**
- Ensure you have the necessary dependencies installed.
- Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
- The number of augmented variations per image, the inclusion of the original images in the output, and the augmentation techniques with their parameters will be automatically set based on your selections.
### Python Code
</div>
"""
    with st.expander("Plug and Play Code"):
        description_col, download_col = st.columns([7, 3])

        # Left column: description first, then the code itself
        with description_col:
            st.markdown(heading_markdown)
            st.markdown(usage_markdown, unsafe_allow_html=True)
            st.code(generated_code, language="python")

        # Right column: the same script offered as a file download
        with download_col:
            # The download payload is the script text encoded as UTF-8 bytes
            script_bytes = generated_code.encode("utf-8")
            st.download_button(
                label="Download Python File",
                data=script_bytes,
                file_name="augmentation_script.py",
                mime="text/plain",
                use_container_width=True,
            )