|
import os |
|
from PIL import Image |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
|
|
|
|
# Dataset root, relative to the current working directory.
base_dir = 'data/chest_xray'

# The validation split and one subdirectory per class inside it; these are
# the directories the augmentation below reads from and writes into.
val_dir = os.path.join(base_dir, 'val')
normal_class_dir, pneumonia_class_dir = (
    os.path.join(val_dir, class_name) for class_name in ('NORMAL', 'PNEUMONIA')
)
|
|
|
|
|
def augment_images(class_directory, num_augmented_images, target_size=(150, 150),
                   output_directory=None):
    """Write randomly augmented copies of the images in one class directory.

    Images are read from ``class_directory`` through Keras'
    ``ImageDataGenerator.flow_from_directory`` in a fixed order
    (``shuffle=False``). Each one is randomly rotated, shifted, sheared,
    zoomed and horizontally flipped, then saved as ``augmented_<i>.png`` for
    ``i`` in ``0 .. num_augmented_images - 1``.

    Args:
        class_directory: Directory holding the source images for one class.
            Its parent is given to ``flow_from_directory`` and its basename is
            used as the only class.
        num_augmented_images: Number of augmented images to write. Zero or a
            negative value writes nothing.
        target_size: ``(height, width)`` that every image is resized to before
            augmentation; saved images have this size. Defaults to the
            previously hard-coded ``(150, 150)``.
        output_directory: Where to save augmented images. Defaults to
            ``class_directory`` (the original behaviour). If given, it is
            created if missing, which keeps generated files out of the source
            set.

    Returns:
        None. Progress is reported with ``print``.
    """
    # Normalize so a trailing separator does not make dirname() return the
    # class directory itself and basename() return ''.
    source_dir = os.path.normpath(class_directory)

    if output_directory is None:
        output_directory = class_directory
    else:
        os.makedirs(output_directory, exist_ok=True)

    # No `rescale`: pixels stay in [0, 255] so they can be saved directly.
    # Scaling by 1/255 and multiplying back, combined with a truncating uint8
    # cast, could turn e.g. 254.99998 into 254.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )

    # flow_from_directory expects a parent directory with one subdirectory
    # per class, so point it at the parent and restrict it to this class.
    generator = datagen.flow_from_directory(
        directory=os.path.dirname(source_dir),
        target_size=target_size,
        batch_size=1,
        class_mode=None,  # yield image batches only, no labels
        shuffle=False,
        classes=[os.path.basename(source_dir)],
    )

    print(f"Found {generator.samples} images in {class_directory}")

    if generator.samples == 0:
        print("No images found in the directory.")
        return

    # NOTE(review): with the default output_directory, files are written back
    # into class_directory, so re-running the script will also augment (and
    # overwrite) earlier augmented_*.png files.
    count = 0
    while count < num_augmented_images:
        try:
            img_batch = next(generator)
        except StopIteration:
            # Defensive only: Keras directory iterators normally cycle
            # through the images indefinitely.
            print("No more images to generate.")
            break

        # Round and clamp before the uint8 cast so interpolated float pixels
        # map to the nearest valid byte instead of being truncated.
        img = img_batch[0].round().clip(0, 255).astype('uint8')
        img_path = os.path.join(output_directory, f"augmented_{count}.png")
        Image.fromarray(img).save(img_path)
        count += 1

    print(f"Total augmented images created: {count}")
|
|
|
|
|
|
|
# Desired number of images per class after augmentation.
target_images_per_class = 2944

# Images still needed per class = target - existing count.
# NOTE(review): 3875 (NORMAL) and 1171 (PNEUMONIA) are hard-coded existing
# counts; confirm they match the 'val' directories actually being augmented.
num_augmented_images_normal = target_images_per_class - 3875
num_augmented_images_pneumonia = target_images_per_class - 1171

# A negative difference means the class already meets the target. Clamp both
# to zero so neither call is asked for a negative number of images.
augment_images(normal_class_dir, max(num_augmented_images_normal, 0))
augment_images(pneumonia_class_dir, max(num_augmented_images_pneumonia, 0))
|
|