Spaces:
Sleeping
Sleeping
import os | |
from PIL import Image | |
from torch.utils.data import Dataset | |
class PlantVillageDataset(Dataset):
    """
    PyTorch-compatible dataset for the cleaned and split PlantVillage dataset.

    Directory structure should be:

        root/
            crop1/
                disease1/
                    img1.jpg
                    ...
                disease2/
                    ...
            crop2/
                ...

    Every ``crop/disease`` directory pair becomes one class named
    ``"{crop}___{disease}"``. Class indices are assigned in sorted
    (crop, disease) order, and the images inside each class are listed in
    sorted file-name order, so ``samples`` is built in the same order for
    the same directory contents.

    Attributes:
        root_dir: The directory that was scanned.
        transform: Optional callable applied to each loaded image.
        samples: List of ``(image_path, class_index)`` tuples.
        class_to_idx: Mapping from class name to integer class index.
    """

    # Lower-case file extensions accepted as images while scanning.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to split directory (e.g., data/split/train)
            transform (callable, optional): Transformations to apply to images
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        self._prepare_dataset()

    def _prepare_dataset(self):
        """
        Scan ``root_dir`` and fill ``samples`` and ``class_to_idx``.

        Entries that are not directories are skipped at the crop and disease
        levels. Inside a disease directory, only names ending in one of
        ``IMAGE_EXTENSIONS`` (case-insensitive) are kept. A disease directory
        without any matching file still receives a class index.
        """
        for crop in sorted(os.listdir(self.root_dir)):
            crop_path = os.path.join(self.root_dir, crop)
            if not os.path.isdir(crop_path):
                continue

            for disease in sorted(os.listdir(crop_path)):
                disease_path = os.path.join(crop_path, disease)
                if not os.path.isdir(disease_path):
                    continue  # Safety check: stray file next to class dirs

                class_name = f"{crop}___{disease}"
                # setdefault evaluates len() before inserting, so a new class
                # gets the next free index and a known class keeps its own.
                label = self.class_to_idx.setdefault(
                    class_name, len(self.class_to_idx)
                )

                # os.listdir returns names in arbitrary order; sort so the
                # sample order does not depend on the filesystem.
                for fname in sorted(os.listdir(disease_path)):
                    if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue
                    self.samples.append(
                        (os.path.join(disease_path, fname), label)
                    )

    def __len__(self):
        """Return the number of discovered image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int): Index into ``samples``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB-converted
            image (after ``transform``, if one was given) and ``label`` is
            the integer class index.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the file as soon as the pixel data has
        # been converted, instead of leaving that to garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label