# CV_Accelerator / Functions / image_augmentation_functions.py
# Importing necessary libraries
import io
import os
import cv2
import utils
import random
import zipfile
import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
import albumentations as A
import matplotlib.pyplot as plt
# Function to fetch the names of augmentation techniques applicable to the selected option
@st.cache_resource(show_spinner=False)  # BUG FIX: the '@' was missing, so the function was never cached
def get_applicable_techniques(df, option):
    """Return the Series of technique names marked "Applicable" for *option*.

    Args:
        df: Reference table with a "Name" column and one applicability
            column per option.
        option: Column name to filter on.

    Returns:
        pd.Series: Names of techniques whose *option* column equals "Applicable".
    """
    # NOTE(review): st.cache_data may be more appropriate for a DataFrame
    # slice than cache_resource — confirm against the app's caching policy.
    return df[df[option] == "Applicable"]["Name"]
# Function to generate unique colors for each given list of unique numbers
def generate_unique_colors(unique_numbers):
    """Assign a distinct BGR color (0-255 ints) to every number in the list.

    Colors are sampled evenly from matplotlib's "hsv" colormap and flipped
    from RGB to BGR so they can be fed directly to OpenCV drawing calls.
    """
    cmap = plt.get_cmap("hsv")
    total = len(unique_numbers)
    colors_bgr = {}
    for position, number in enumerate(unique_numbers):
        # Sample the colormap at an evenly spaced fraction; keep only RGB.
        red, green, blue = cmap(position / total)[:3]
        # OpenCV expects BGR channel order scaled to 0-255.
        colors_bgr[number] = (int(blue * 255), int(green * 255), int(red * 255))
    return colors_bgr
# Function to adjust zero values to a small positive number
def adjust_zero_value(value):
    """Replace an exact zero with a tiny positive epsilon; pass others through."""
    return 1e-9 if value == 0 else value
# Function to parse a YOLO label file and convert it into Albumentations-compatible Bboxes format
def bboxes_label(label_file, class_dict):
    """Parse YOLO bbox label lines (bytes) into an Albumentations-friendly dict.

    Args:
        label_file: Iterable of byte strings, one "cls xc yc w h" per line.
        class_dict: Mapping whose keys are the shifted (1-based) class ids.

    Returns:
        dict with "bboxes" and "class_labels" lists, or None when the file is
        empty, malformed, out of range, or references an unknown class id.
    """
    parsed = {"bboxes": [], "class_labels": []}
    for raw_line in label_file:
        try:
            fields = raw_line.decode().strip().split()
            # Exactly five numeric fields per line; anything else raises and
            # the whole file is rejected below.
            class_id, x_center, y_center, width, height = [float(f) for f in fields]
            # Shift class ids up by one: 0 is reserved for the background.
            class_id += 1
            # Albumentations rejects exact zeros that are legal in YOLO.
            box = [adjust_zero_value(v) for v in (x_center, y_center, width, height)]
            within_range = all(0 <= v <= 1 for v in box)
            if not within_range or class_id not in class_dict:
                return None  # Out-of-range coordinate or unknown class id
            parsed["bboxes"].append(box)
            parsed["class_labels"].append(class_id)
        except Exception:
            return None  # Any parse failure invalidates the whole file
    # Empty files are treated as invalid as well.
    return parsed if parsed["bboxes"] else None
# Function to parse a YOLO label file and convert it into compatible Mask format
def masks_label(label_file, class_dict):
    """Parse YOLO polygon label lines (bytes) into a mask-style dict.

    Args:
        label_file: Iterable of byte strings, one "cls x1 y1 x2 y2 ..." per line.
        class_dict: Mapping whose keys are the shifted (1-based) class ids.

    Returns:
        dict with "masks" (lists of (x, y) tuples) and "class_labels", or None
        when the file is empty, malformed, out of range, or uses an unknown id.
    """
    parsed = {"masks": [], "class_labels": []}
    for raw_line in label_file:
        try:
            tokens = raw_line.decode().strip().split()
            # Shift class ids up by one: 0 is reserved for the background.
            class_id = int(tokens[0]) + 1
            coords = [float(t) for t in tokens[1:]]
            if class_id not in class_dict or not all(0 <= c <= 1 for c in coords):
                return None  # Unknown class id or coordinate out of [0, 1]
            # Pair up consecutive coordinates; an odd count raises IndexError
            # and the file is rejected below.
            polygon = [(coords[i], coords[i + 1]) for i in range(0, len(coords), 2)]
            parsed["class_labels"].append(class_id)
            parsed["masks"].append(polygon)
        except Exception:
            return None  # Any parse failure invalidates the whole file
    # Empty files are treated as invalid as well.
    return parsed if parsed["masks"] else None
# Function to generate mask for albumentations format
def generate_mask(masks, class_ids, image_height, image_width):
    """Rasterize normalized polygons into a single int32 class-id mask.

    Args:
        masks: List of polygons, each a list of normalized (x, y) tuples.
        class_ids: Class id per polygon; drawn as the pixel value.
        image_height / image_width: Target mask dimensions.

    Returns:
        np.ndarray of shape (image_height, image_width), 0 = background.
    """
    # Background is 0 everywhere until a polygon is painted over it.
    canvas = np.zeros((image_height, image_width), dtype=np.int32)
    for polygon, class_id in zip(masks, class_ids):
        # Scale normalized coordinates up to pixel positions.
        scaled = np.array(
            [(int(x * image_width), int(y * image_height)) for x, y in polygon],
            dtype=np.int32,
        )
        # Paint the filled polygon with its class id.
        cv2.fillPoly(canvas, [scaled], color=class_id)
    return canvas
# Function to convert a single-channel mask back to YOLO format
def mask_to_yolo(mask):
    """Convert a class-id mask back into normalized YOLO-style polygons.

    Args:
        mask: 2-D array where each pixel holds a class id (0 = background).

    Returns:
        dict with "masks" (lists of normalized (x, y) tuples, one entry per
        external contour) and matching "class_labels".
    """
    result = {"masks": [], "class_labels": []}
    height, width = mask.shape[0], mask.shape[1]
    for class_value in np.unique(mask):
        if class_value == 0:
            continue  # 0 is the background, not an object
        # Binary mask isolating just this class.
        binary = np.uint8(mask == class_value)
        # Outer outlines only, with redundant points compressed.
        contours, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            # Normalize pixel coordinates back into the [0, 1] range.
            outline = [
                (point[0][0] / width, point[0][1] / height) for point in contour
            ]
            result["masks"].append(outline)
            result["class_labels"].append(class_value)
    return result
# Function to create a user interface for adjusting bounding box parameters
def bbox_params():
    """Render Streamlit widgets collecting Albumentations bounding-box params.

    Returns:
        dict: keyword arguments for ``A.BboxParams`` — min_area,
        min_visibility, min_width, min_height, check_each_transform.
    """
    with st.expander("Bounding Box Parameters"):
        # Two columns side by side for the numeric inputs
        col1, col2 = st.columns(2)
        with col1:
            # Boxes with area below this threshold are removed by Albumentations.
            min_area = st.number_input(
                "Minimum Area",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum area of a bounding box. Boxes smaller than this will be removed.",
                on_change=utils.reset_validation_trigger,  # invalidate previous validation on edit
                key=st.session_state["bbox1_key"],
            )
            # NOTE(review): Albumentations treats min_visibility as a fraction
            # in [0, 1], but this widget only accepts integers 0-1000 —
            # confirm the intended range.
            min_visibility = st.number_input(
                "Minimum Visibility",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum fraction of area for a bounding box to remain in the list.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox2_key"],
            )
        with col2:
            # Boxes narrower than this are removed.
            min_width = st.number_input(
                "Minimum Width",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum width of a bounding box. Boxes narrower than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox3_key"],
            )
            # Boxes shorter than this are removed.
            min_height = st.number_input(
                "Minimum Height",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum height of a bounding box. Boxes shorter than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox4_key"],
            )
        # Whether box validity is re-checked after every dual transform
        check_each_transform = st.checkbox(
            "Check Each Transform",
            help="If checked, bounding boxes will be checked after each dual transform.",
            value=True,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["bbox5_key"],
        )
    # Return the collected parameters as a dictionary
    return {
        "min_area": min_area,
        "min_visibility": min_visibility,
        "min_width": min_width,
        "min_height": min_height,
        "check_each_transform": check_each_transform,
    }
# Function to check if the uploaded images and labels are valid
@st.cache_resource(show_spinner=False)  # BUG FIX: the '@' was missing, so no caching occurred
def check_valid_labels(uploaded_files, selected_option, class_dict):
    """Validate uploaded images and (optionally) their YOLO label files.

    Args:
        uploaded_files: Streamlit UploadedFile objects (images and/or .txt labels).
        selected_option: "Bboxes" or "Masks" — controls how .txt files are parsed.
        class_dict: Mapping of shifted class ids to class names.

    Returns:
        tuple: (is_valid, image_files, label_files, first_image, first_label),
        where the dicts are keyed by file name without extension.

    NOTE(review): cached arguments must be hashable by Streamlit — verify
    UploadedFile hashing behaves in the deployed Streamlit version.
    """
    # Early exit if no files are uploaded
    if len(uploaded_files) == 0:
        st.warning("Please upload at least one image to apply augmentation.", icon="⚠️")
        return False, {}, {}, None, None
    # Dictionaries keyed by file stem (name without extension)
    image_files, label_files = {}, {}
    # Stem of the first uploaded file, used to pick the preview image/label
    first_file_name = os.path.splitext(uploaded_files[0].name)[0]
    # Counters used for duplicate detection further below
    image_count, label_count = 0, 0
    # Progress UI elements
    progress_bar = st.progress(0)
    progress_text = st.empty()
    total_files = len(uploaded_files)
    # Categorize and prepare uploaded files
    for index, file in enumerate(uploaded_files):
        file.seek(0)  # Reset file pointer to ensure proper file reading
        file_name_without_extension = os.path.splitext(file.name)[0]
        # Distribute files into image or label categories based on MIME type;
        # any other type is silently ignored.
        if file.type in ["image/jpeg", "image/png"]:
            image_files[file_name_without_extension] = file
            image_count += 1
        elif file.type == "text/plain":
            file_content = file.readlines()
            # BUG FIX: label_data was previously unbound (UnboundLocalError)
            # when selected_option was neither "Bboxes" nor "Masks"; default
            # to None so such files are reported as invalid instead.
            label_data = None
            if selected_option == "Bboxes":
                label_data = bboxes_label(file_content, class_dict)
            elif selected_option == "Masks":
                label_data = masks_label(file_content, class_dict)
            # Check for valid label data
            if label_data is None:
                st.warning(
                    f"Invalid label format or data in file: {file.name}",
                    icon="⚠️",
                )
                return False, {}, {}, None, None
            label_files[file_name_without_extension] = label_data
            label_count += 1
        # Update progress bar and display current progress
        progress_percentage = (index + 1) / total_files
        progress_bar.progress(progress_percentage)
        progress_text.text(f"Validating file {index + 1} of {total_files}")
    # Unique stems; a duplicate stem overwrites its dict entry, so the set
    # ends up smaller than the counter and reveals the duplicate.
    unique_image_names = set(image_files.keys())
    unique_label_names = set(label_files.keys())
    # Remove progress bar and progress text after processing
    progress_bar.empty()
    progress_text.empty()
    if (len(unique_image_names) != image_count) or (
        len(unique_label_names) != label_count
    ):
        # Warn the user about the presence of duplicate file names
        st.warning(
            "Duplicate file names detected. Please ensure each image and label has a unique name.",
            icon="⚠️",
        )
        return False, {}, {}, None, None
    # Cross-check images against labels
    if (len(image_files) > 0) and (len(label_files) > 0):
        # Both counts and stems must match pairwise
        if (len(image_files) == len(label_files)) and (
            unique_image_names == unique_label_names
        ):
            st.info(
                f"Validated: {len(image_files)} images and labels successfully matched.",
                icon="✅",
            )
            return (
                True,
                image_files,
                label_files,
                image_files[first_file_name],
                label_files[first_file_name],
            )
        elif len(image_files) != len(label_files):
            # Warn if the count of images and labels does not match
            st.warning(
                "Count Mismatch: The number of uploaded images and labels does not match.",
                icon="⚠️",
            )
            return False, {}, {}, None, None
        else:
            # Warn if there is a mismatch in file names between images and labels
            st.warning(
                "Mismatch detected: Some images do not have corresponding label files.",
                icon="⚠️",
            )
            return False, {}, {}, None, None
    elif len(image_files) > 0:
        # Images without labels are allowed; labels are simply ignored
        st.info(
            f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
            icon="✅",
        )
        return True, image_files, {}, image_files[first_file_name], None
    else:
        # Only label files (or nothing usable) were uploaded
        st.warning("Please upload an image to apply augmentation.", icon="⚠️")
        return False, {}, {}, None, None
# Function to apply an augmentation technique to an image and return any errors along with the processed image
def apply_and_test_augmentation(
    augmentation,
    params,
    image,
    label,
    label_type,
    label_input_parameters,
    allowed_image_types,
):
    """Apply a single augmentation technique to ``image`` as a smoke test.

    Args:
        augmentation: Technique name (resolved via ``utils.apply_albumentation``).
        params: Parameter dict for the technique.
        image: Input image as a NumPy array (HxW or HxWxC).
        label: Parsed label dict, or None when no labels were uploaded.
        label_type: "Bboxes", "Masks", or None.
        label_input_parameters: ``A.BboxParams`` kwargs; used only for "Bboxes".
        allowed_image_types: dtype/channel combinations the technique accepts.

    Returns:
        tuple: (error_occurred: bool, processed_image or None).
    """
    try:
        # Determine dtype and channel count of the input image
        input_image_type = image.dtype
        num_channels = (
            image.shape[2] if len(image.shape) == 3 else 1
        )  # 2-D arrays are treated as single-channel images
        # Reject images whose dtype/channel combo the technique cannot handle
        if not utils.is_image_type_allowed(
            input_image_type, num_channels, allowed_image_types
        ):
            # Format the allowed types for display in the warning message
            allowed_types_formatted = ", ".join(map(str, allowed_image_types))
            # Display a warning message specifying the acceptable image types
            st.warning(
                f"Error applying {augmentation}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
                icon="⚠️",
            )
            return True, None  # Error occurred
        # Fix the RNG seed so the preview is reproducible across reruns.
        # NOTE(review): the seed is always 0 here despite similar code
        # elsewhere seeding per iteration — confirm this is intentional.
        random.seed(0)
        if label is None:
            # No labels: compose and run the single technique on the image alone
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            processed_image = transform(image=image)["image"]
            return False, processed_image
        elif label_type == "Bboxes":
            # YOLO bounding boxes: thread box constraints through BboxParams
            transform = A.Compose(
                [utils.apply_albumentation(params, augmentation)],
                bbox_params=A.BboxParams(
                    format="yolo",
                    label_fields=["class_labels"],
                    min_area=label_input_parameters["min_area"],
                    min_visibility=label_input_parameters["min_visibility"],
                    min_width=label_input_parameters["min_width"],
                    min_height=label_input_parameters["min_height"],
                    check_each_transform=label_input_parameters["check_each_transform"],
                ),
            )
            # Only the transformed image is kept; boxes are discarded here
            processed_image = transform(
                image=image,
                bboxes=label["bboxes"],
                class_labels=label["class_labels"],
            )["image"]
        elif label_type == "Masks":
            # Mask labels: rasterize the polygons into a class-id mask first
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            processed_image = transform(
                image=image,
                mask=generate_mask(
                    label["masks"],
                    label["class_labels"],
                    image.shape[0],
                    image.shape[1],
                ),
            )["image"]
        return False, processed_image  # No error
    except Exception as e:
        # Any failure (including an unexpected label_type leaving
        # processed_image unbound) is surfaced as a warning
        st.warning(f"Error applying {augmentation}: {e}", icon="⚠️")
        return True, None  # Error occurred
# Generates a DataFrame detailing augmentation technique parameters and descriptions
def create_augmentations_dataframe(augmentations_params, augmentation_params_db):
    """Tabulate chosen augmentation parameters with their descriptions.

    Args:
        augmentations_params: Dict of technique name -> {param name: value}.
        augmentation_params_db: Reference table with "Name", "Parameter Name"
            and "Parameter Description" columns.

    Returns:
        pd.DataFrame with columns ["augmentation", "Parameter", "Value",
        "Description"], one row per (technique, parameter) pair.
    """
    rows = []
    for aug_name, params in augmentations_params.items():
        # All reference rows belonging to this augmentation technique.
        technique_rows = augmentation_params_db[
            augmentation_params_db["Name"] == aug_name
        ]
        for param_name, param_value in params.items():
            matches = technique_rows[
                technique_rows["Parameter Name"] == param_name
            ]
            # Fall back to a placeholder when the parameter is not documented.
            if matches.empty:
                description = "Description not available"
            else:
                description = matches["Parameter Description"].iloc[0]
            rows.append([aug_name, param_name, param_value, description])
    return pd.DataFrame(
        rows, columns=["augmentation", "Parameter", "Value", "Description"]
    )
# Function to Generate Python Code for Augmentation with Bounding Box Labels
def generate_python_code_bboxes(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    """Build a standalone Python script that batch-augments images with YOLO bbox labels.

    Args:
        augmentations_params: Dict of technique name -> parameter dict, baked
            verbatim into the generated ``A.Compose`` pipeline.
        label_input_parameters: ``A.BboxParams`` kwargs baked into the script.
        num_variations: Number of augmented copies per image (baked in as a literal).
        include_original: When True, the script also copies originals to the output.

    Returns:
        str: The generated Python source code.
    """
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"
    # Paths for input and output directories (user fills these in)
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"
    # Reader: parses one "cls xc yc w h" line per box
    code_str += "# Function to read YOLO format labels\n"
    code_str += "def read_yolo_label(label_path):\n"
    code_str += "    bboxes = []\n"
    code_str += "    class_ids = []\n"
    code_str += "    with open(label_path, 'r') as file:\n"
    code_str += "        for line in file:\n"
    code_str += "            class_id, x_center, y_center, width, height = map(float, line.split())\n"
    code_str += "            bboxes.append([x_center, y_center, width, height])\n"
    code_str += "            class_ids.append(int(class_id))\n"
    code_str += "    return bboxes, class_ids\n\n"
    # Pipeline: one A.<Technique>(...) line per selected augmentation
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image, bboxes, class_ids):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ], bbox_params=A.BboxParams(\n"
    code_str += f"        format='yolo',\n"
    code_str += f"        label_fields=['class_labels'],\n"
    code_str += f"        min_area={label_input_parameters['min_area']},\n"
    code_str += f"        min_visibility={label_input_parameters['min_visibility']},\n"
    code_str += f"        min_width={label_input_parameters['min_width']},\n"
    code_str += f"        min_height={label_input_parameters['min_height']},\n"
    code_str += f"        check_each_transform={label_input_parameters['check_each_transform']}\n"
    code_str += "    ))\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += (
        "    return pipeline(image=image, bboxes=bboxes, class_labels=class_ids)\n\n"
    )
    # Batch driver over every image in the input directory
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n"
    code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
    code_str += "            # Read the image and label\n"
    code_str += "            image = cv2.imread(image_path)\n"
    code_str += "            bboxes, class_ids = read_yolo_label(label_path)\n\n"
    # The copy lines are only emitted when include_original was selected;
    # the comment line is emitted unconditionally.
    code_str += "            # Include original image and label\n"
    if include_original:
        code_str += "            shutil.copy2(image_path, output_directory)\n"
        code_str += "            shutil.copy2(label_path, output_directory)\n\n"
    # Variation loop with num_variations baked in as a literal
    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += (
        "                processed_data = process_image(image, bboxes, class_ids)\n"
    )
    code_str += "                processed_image = processed_data['image']\n"
    code_str += "                processed_bboxes = processed_data['bboxes']\n"
    code_str += (
        "                processed_class_ids = processed_data['class_labels']\n\n"
    )
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
    code_str += "                with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
    code_str += "                    for bbox, class_id in zip(processed_bboxes, processed_class_ids):\n"
    code_str += "                        label_line = ' '.join(map(str, [class_id] + list(bbox)))\n"
    code_str += "                        label_file.write(label_line + '\\n')\n\n"
    # Entry-point call for the generated script
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += f"process_batch(input_directory, output_directory)\n"
    return code_str
def generate_python_code_masks(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    """Build a standalone Python script that batch-augments images with mask labels.

    Args:
        augmentations_params: Dict of technique name -> parameter dict, baked
            verbatim into the generated ``A.Compose`` pipeline.
        label_input_parameters: Unused here; kept for signature parity with
            the bbox generator.
        num_variations: Number of augmented copies per image (baked in as a literal).
        include_original: When True, the script also copies originals to the output.

    Returns:
        str: The generated Python source code.
    """
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport numpy as np\nimport albumentations as A\n\n"
    # Paths for input and output directories (user fills these in)
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"
    # Reader: rasterizes YOLO polygons into an int mask (-1 = background)
    code_str += "# Function to read YOLO mask format and convert to mask\n"
    code_str += "def read_yolo_label(label_path, image_shape):\n"
    code_str += "    mask = np.full(image_shape, -1, dtype=np.int32)\n"
    code_str += "    with open(label_path, 'r') as file:\n"
    code_str += "        for line in file:\n"
    code_str += "            parts = line.strip().split()\n"
    code_str += "            class_id = int(parts[0])\n"
    code_str += (
        "            points = np.array([float(p) for p in parts[1:]]).reshape(-1, 2)\n"
    )
    code_str += "            points = (points * [image_shape[1], image_shape[0]]).astype(np.int32)\n"
    code_str += "            cv2.fillPoly(mask, [points], class_id)\n"
    code_str += "    return mask\n\n"
    # Writer: converts a mask back into YOLO polygon text
    code_str += "# Function to convert mask to YOLO format\n"
    code_str += "def mask_to_yolo(mask):\n"
    code_str += "    yolo_format = ''\n"
    code_str += "    for class_id in np.unique(mask):\n"
    code_str += "        if class_id == -1:\n"
    code_str += "            continue\n"
    code_str += "        contours, _ = cv2.findContours(\n"
    code_str += "            np.uint8(mask == class_id), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n"
    code_str += "        )\n"
    code_str += "        for contour in contours:\n"
    code_str += "            contour = contour.flatten().tolist()\n"
    code_str += "            normalized_contour = [\n"
    # NOTE(review): the emitted expression divides x (even index) by
    # mask.shape[0] (height) and y by shape[1] (width) — for non-square
    # images the normalization axes look swapped; confirm and fix the template.
    code_str += "                str(coord / mask.shape[i % 2]) for i, coord in enumerate(contour)\n"
    code_str += "            ]\n"
    code_str += "            yolo_format += f'{class_id} ' + ' '.join(normalized_contour) + '\\n'\n"
    code_str += "    return yolo_format\n\n"
    # Pipeline: one A.<Technique>(...) line per selected augmentation
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image, mask):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ])\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += "    return pipeline(image=image, mask=mask)\n\n"
    # Batch driver over every image in the input directory
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory, include_original=False, num_variations=1):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n"
    code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
    code_str += "            # Read the image\n"
    code_str += "            image = cv2.imread(image_path)\n\n"
    # The copy lines are only emitted when include_original was selected;
    # the comment line is emitted unconditionally.
    code_str += "            # Include original image\n"
    if include_original:
        code_str += "            shutil.copy2(image_path, output_directory)\n"
        code_str += "            shutil.copy2(label_path, output_directory)\n\n"
    code_str += "            # Check if label file exists and read mask\n"
    code_str += "            mask = None\n"
    code_str += "            if os.path.exists(label_path):\n"
    code_str += (
        "                mask = read_yolo_label(label_path, image.shape[:2])\n\n"
    )
    # Variation loop with num_variations baked in as a literal
    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += "                processed_image, processed_mask = image, mask\n"
    code_str += "                if mask is not None:\n"
    code_str += "                    processed_data = process_image(image, mask)\n"
    code_str += "                    processed_image, processed_mask = processed_data['image'], processed_data['mask']\n\n"
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}.jpg'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
    code_str += "                # Save the processed label in YOLO format\n"
    code_str += "                if processed_mask is not None:\n"
    code_str += (
        "                    processed_label_str = mask_to_yolo(processed_mask)\n"
    )
    code_str += "                    with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
    code_str += "                        label_file.write(processed_label_str)\n\n"
    # Entry-point call for the generated script. NOTE(review): it does not
    # forward include_original/num_variations — those are baked in above, so
    # the generated function's defaults are effectively unused.
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += "process_batch(input_directory, output_directory)\n"
    return code_str
# Function to create an augmentation pipeline based on the selected techniques and their parameters
def create_augmentation_pipeline(
    selected_augmentations, augmentation_params, label_type, label_input_parameters=None
):
    """Compose the selected augmentations into one Albumentations transform.

    Args:
        selected_augmentations: Ordered list of technique names.
        augmentation_params: Dict of technique name -> parameter dict.
        label_type: "Bboxes", "Masks", or None.
        label_input_parameters: ``A.BboxParams`` kwargs; required for "Bboxes".

    Returns:
        The composed ``A.Compose`` transform, or None on failure.
    """
    # Resolve each selected technique into an Albumentations transform
    pipeline = [
        utils.apply_albumentation(augmentation_params[aug_name], aug_name)
        for aug_name in selected_augmentations
    ]
    try:
        # Fixed seed so composing the pipeline is reproducible
        random.seed(0)
        if label_type is None:
            # No labels: plain composition
            return A.Compose(pipeline)
        elif label_type == "Bboxes":
            # BUG FIX: always return the composed transform for the Bboxes
            # branch (it previously fell through without a guaranteed return).
            return A.Compose(
                pipeline,
                bbox_params=A.BboxParams(
                    format="yolo",
                    label_fields=["class_labels"],
                    min_area=label_input_parameters["min_area"],
                    min_visibility=label_input_parameters["min_visibility"],
                    min_width=label_input_parameters["min_width"],
                    min_height=label_input_parameters["min_height"],
                    check_each_transform=label_input_parameters["check_each_transform"],
                ),
            )
        elif label_type == "Masks":
            # Masks are passed via the mask= argument at call time; no extra params
            return A.Compose(pipeline)
        # BUG FIX: an unknown label type previously raised NameError and was
        # swallowed by the except below; report it explicitly instead.
        st.warning(f"Error applying augmentation: unsupported label type {label_type!r}")
        return None
    except Exception as e:
        # BUG FIX: the original f-string had no placeholder and dropped the
        # exception detail entirely.
        st.warning(f"Error applying augmentation: {e}")
        return None  # Error occurred
# Function to convert label data from dictionary format to YOLO format
def convert_labels_to_yolo_format(label_data, class_dict):
    """Serialize a parsed label dict back into YOLO label-file text.

    Args:
        label_data: Dict that may contain "bboxes" and/or "masks" plus
            "class_labels" (ids shifted up by one during parsing).
        class_dict: Unused; kept for interface compatibility with callers.

    Returns:
        str: One YOLO-formatted line per box/polygon, each newline-terminated.
    """
    lines = []
    # Bounding boxes: "cls xc yc w h" per line
    if "bboxes" in label_data:
        for box, class_label in zip(label_data["bboxes"], label_data["class_labels"]):
            # Undo the +1 shift applied at parse time
            parts = [class_label - 1, *box]
            lines.append(" ".join(str(p) for p in parts))
    # Polygon masks: "cls x1 y1 x2 y2 ..." per line
    if "masks" in label_data:
        for polygon, class_label in zip(label_data["masks"], label_data["class_labels"]):
            flattened = [coord for point in polygon for coord in point]
            coord_text = " ".join(str(c) for c in flattened)
            lines.append(f"{class_label - 1} {coord_text}")
    return "".join(line + "\n" for line in lines)
# Function to apply the augmentation pipeline to the image based on the label type
def apply_augmentation_pipeline(image, label_file, label_type, transform):
    """Run the composed transform on one image plus its (optional) labels.

    Args:
        image: Input image as a NumPy array.
        label_file: Parsed label dict, or None when label_type is None.
        label_type: "Bboxes", "Masks", or None.
        transform: The composed Albumentations pipeline.

    Returns:
        tuple: (processed image, processed label) — the label is None for
        unlabeled input, a bbox dict for "Bboxes", and a YOLO-style polygon
        dict for "Masks".
    """
    if label_type is None:
        # Unlabeled input: only the image goes through the pipeline
        output = transform(image=image)
        new_label = None
    elif label_type == "Bboxes":
        output = transform(
            image=image,
            bboxes=label_file["bboxes"],
            class_labels=label_file["class_labels"],
        )
        new_label = {
            "bboxes": output["bboxes"],
            "class_labels": output["class_labels"],
        }
    elif label_type == "Masks":
        # Rasterize the polygons, transform, then vectorize back to YOLO form
        rasterized = generate_mask(
            label_file["masks"],
            label_file["class_labels"],
            image.shape[0],
            image.shape[1],
        )
        output = transform(image=image, mask=rasterized)
        new_label = mask_to_yolo(output["mask"])
    return output["image"], new_label
# Function to process images and labels, apply augmentations, and create a zip file with the results
@st.cache_resource(show_spinner=False)
def process_images_and_labels(
    image_files,
    label_files,
    selected_augmentations,
    _augmentations_params,
    label_type,
    label_input_parameters,
    num_variations,
    include_original,
    class_dict,
):
    """Apply the augmentation pipeline to every uploaded image and zip the results.

    Fills the session-state image repository/mappings used by the preview UI
    and stores the finished zip bytes in
    ``st.session_state["zip_data_augmentation"]``.

    Args:
        image_files: Dict of file stem -> uploaded image file.
        label_files: Dict of file stem -> parsed label dict (may be empty).
        selected_augmentations: Ordered list of technique names.
        _augmentations_params: Per-technique parameter dicts (the leading
            underscore excludes this arg from Streamlit's cache hashing).
        label_type: "Bboxes", "Masks", or None.
        label_input_parameters: ``A.BboxParams`` kwargs for "Bboxes".
        num_variations: Number of augmented copies per image.
        include_original: Whether originals are also written to the zip.
        class_dict: Mapping of shifted class ids to class names.
    """
    zip_buffer = io.BytesIO()  # In-memory buffer that backs the zip file
    # Repository of processed image data for the preview UI
    st.session_state[
        "image_repository_augmentation"
    ] = {}
    # Map each original image name to the names of its processed versions
    st.session_state[
        "processed_image_mapping_augmentation"
    ] = {}
    st.session_state["unique_images_names"] = []  # Names of the original images
    # Progress UI elements
    progress_bar = st.progress(0)
    progress_text = st.empty()
    with zipfile.ZipFile(
        zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
    ) as zip_file:
        # Ignore the label type entirely when no labels were uploaded
        effective_label_type = None if len(label_files) == 0 else label_type
        # One pipeline shared by all images and variations
        transform = create_augmentation_pipeline(
            selected_augmentations,
            _augmentations_params,
            effective_label_type,
            label_input_parameters,
        )
        total_images = len(image_files) * num_variations
        processed_count = 0  # Counter for processed images
        # Iterate over each uploaded file
        for image_name, image_file in image_files.items():
            image_file.seek(0)  # Reset file pointer to start
            file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
            # Decode and convert from OpenCV's BGR to RGB
            original_image = cv2.cvtColor(
                cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
            )
            original_image_resized = utils.resize_image(original_image)
            # Include original images and labels in the output if selected
            if include_original:
                original_img_buffer = io.BytesIO()
                Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
                zip_file.writestr(image_file.name, original_img_buffer.getvalue())
                # Convert and save original labels to YOLO format if they exist
                label_file = label_files.get(image_name)
                if label_file is not None:
                    yolo_label_str = convert_labels_to_yolo_format(
                        label_file, class_dict
                    )
                    zip_file.writestr(f"{image_name}.txt", yolo_label_str)
            original_file_name = image_file.name
            st.session_state["unique_images_names"].append(original_file_name)
            st.session_state["processed_image_mapping_augmentation"][
                original_file_name
            ] = []
            st.session_state["image_repository_augmentation"][image_file.name] = {
                "image": original_image_resized,
                "label": label_files.get(image_name),
            }
            # Apply augmentations and generate variations
            for i in range(num_variations):
                # Seed per variation so each copy is different but reproducible
                random.seed(i)
                (
                    processed_image,
                    processed_label,
                ) = apply_augmentation_pipeline(
                    original_image,
                    label_files.get(image_name),
                    effective_label_type,
                    transform,
                )
                img_buffer = io.BytesIO()
                Image.fromarray(processed_image).save(img_buffer, format="JPEG")
                # image_name is already extension-free, so split('.') is a no-op guard
                processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
                zip_file.writestr(processed_filename, img_buffer.getvalue())
                processed_image_resized = utils.resize_image(processed_image)
                st.session_state["processed_image_mapping_augmentation"][
                    image_file.name
                ].append(processed_filename)
                st.session_state["image_repository_augmentation"][
                    processed_filename
                ] = {
                    "image": processed_image_resized,
                    "label": processed_label,
                }
                # Convert and save processed labels to YOLO format if they exist
                label_file = label_files.get(image_name)
                if label_file is not None:
                    # NOTE(review): for "Masks", processed_label comes from
                    # mask_to_yolo and carries numpy scalar class ids —
                    # confirm the serialized text matches the YOLO format.
                    processed_label_str = convert_labels_to_yolo_format(
                        processed_label, class_dict
                    )
                    zip_file.writestr(
                        f"processed_{image_name.split('.')[0]}_{i}.txt",
                        processed_label_str,
                    )
                processed_count += 1
                # Update progress bar and text
                progress_bar.progress(processed_count / total_images)
                progress_text.text(
                    f"Processing image {processed_count} of {total_images}"
                )
    # Remove the progress bar and text after processing is complete
    progress_bar.empty()
    progress_text.empty()
    zip_buffer.seek(0)  # Rewind so the download reads from the start
    st.session_state["zip_data_augmentation"] = zip_buffer.getvalue()
# Function to overlay labels on images
def overlay_labels(image, labels_to_plot, label_file, label_type, colors, class_dict):
    """Draw bounding boxes or mask polygons (with class names) onto an image.

    Args:
        image: Image array; grayscale inputs are promoted to 3-channel RGB.
        labels_to_plot: Collection of class ids that should be rendered.
        label_file: Parsed label dict ("bboxes" or "masks" + "class_labels").
        label_type: "Bboxes" or "Masks".
        colors: Mapping of class id -> BGR color tuple.
        class_dict: Mapping of class id -> display name.

    Returns:
        The annotated image array.
    """
    # Drawing colored shapes requires three channels.
    if len(image.shape) == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    img_h, img_w = image.shape[0], image.shape[1]
    if label_type == "Bboxes":
        for box, class_id in zip(label_file["bboxes"], label_file["class_labels"]):
            if class_id not in labels_to_plot:
                continue
            # YOLO center/size -> pixel corner coordinates.
            x_center, y_center, box_w, box_h = box
            xmin = int((x_center - box_w / 2) * img_w)
            xmax = int((x_center + box_w / 2) * img_w)
            ymin = int((y_center - box_h / 2) * img_h)
            ymax = int((y_center + box_h / 2) * img_h)
            image = cv2.rectangle(
                image, (xmin, ymin), (xmax, ymax), colors[class_id], thickness=2
            )
            # Class name just above the box, in black.
            cv2.putText(
                image,
                class_dict.get(class_id, "Unknown"),
                (xmin, ymin - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),
                2,
            )
    elif label_type == "Masks":
        for polygon, class_id in zip(label_file["masks"], label_file["class_labels"]):
            if class_id not in labels_to_plot:
                continue
            # Normalized polygon -> pixel coordinates.
            pixel_points = [
                (int(x * img_w), int(y * img_h)) for x, y in polygon
            ]
            # Draw the filled polygon on a copy, then blend 50/50 back into
            # the image for a translucent overlay.
            overlay = image.copy()
            cv2.fillPoly(
                overlay, [np.array(pixel_points, dtype=np.int32)], colors[class_id]
            )
            cv2.addWeighted(overlay, 0.5, image, 0.5, 0, image)
            # Class name near the polygon's first vertex, in black.
            cv2.putText(
                image,
                class_dict.get(class_id, "Unknown"),
                pixel_points[0],
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),
                2,
            )
    return image
# Function to generate a downloadable file
def display_code_and_download_button(generated_code):
    """Show the generated augmentation script with a description and download button.

    Args:
        generated_code: Python source produced by one of the code generators.
    """
    def generate_downloadable_file(code_str):
        # st.download_button expects bytes, so encode the script as UTF-8
        return code_str.encode("utf-8")
    # Description and code on the left, download button on the right
    with st.expander("Plug and Play Code"):
        col1, col2 = st.columns([7, 3])
        with col1:
            st.markdown(
                """
                ### Description of the Code Pipeline
                """
            )
            st.markdown(
                """
                <div style='text-align: justify;'>
                This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.
                **To use this script:**
                - Ensure you have the necessary dependencies installed.
                - Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
                - The number of augmented variations per image, the inclusion of the original images in the output, and the augmentation techniques with their parameters will be automatically set based on your selections.
                ### Python Code
                </div>
                """,
                unsafe_allow_html=True,
            )
            # Render the script with Python syntax highlighting
            st.code(generated_code, language="python")
        with col2:
            # Offer the same script as a downloadable .py file
            st.download_button(
                label="Download Python File",
                data=generate_downloadable_file(generated_code),
                file_name="augmentation_script.py",
                mime="text/plain",
                use_container_width=True,
            )