|
from lightning import LightningDataModule |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
class Datamodule(LightningDataModule): |
|
def __init__( |
|
self, |
|
train_dataset: Dataset, |
|
eval_dataset: Dataset, |
|
batch_train_size: int, |
|
num_workers: int, |
|
eval_batch_size: int = None, |
|
): |
|
super().__init__() |
|
|
|
self.train_dataset = train_dataset |
|
self.eval_dataset = eval_dataset |
|
|
|
self.batch_train_size = batch_train_size |
|
self.eval_batch_size = ( |
|
eval_batch_size if eval_batch_size is not None else batch_train_size |
|
) |
|
|
|
self.num_workers = num_workers |
|
|
|
def train_dataloader(self) -> DataLoader: |
|
"""Load train set loader.""" |
|
persistent_workers = True if self.num_workers > 0 else False |
|
|
|
dataloader = DataLoader( |
|
self.train_dataset, |
|
batch_size=self.batch_train_size, |
|
num_workers=self.num_workers, |
|
pin_memory=True, |
|
persistent_workers=persistent_workers, |
|
) |
|
return dataloader |
|
|
|
def val_dataloader(self) -> DataLoader: |
|
"""Load val set loader.""" |
|
persistent_workers = True if self.num_workers > 0 else False |
|
|
|
dataloader = DataLoader( |
|
self.eval_dataset, |
|
batch_size=self.eval_batch_size, |
|
num_workers=self.num_workers, |
|
pin_memory=True, |
|
persistent_workers=persistent_workers, |
|
) |
|
return dataloader |
|
|
|
def predict_dataloader(self) -> DataLoader: |
|
"""Load predict set loader.""" |
|
dataloader = DataLoader( |
|
self.eval_dataset, |
|
batch_size=self.eval_batch_size, |
|
num_workers=self.num_workers, |
|
) |
|
return dataloader |
|
|
|
def test_dataloader(self) -> DataLoader: |
|
"""Load test set loader.""" |
|
dataloader = DataLoader( |
|
self.eval_dataset, |
|
batch_size=self.eval_batch_size, |
|
num_workers=self.num_workers, |
|
) |
|
return dataloader |
|
|