geekyrakshit commited on
Commit
ecc0237
·
1 Parent(s): 192c48a

updated mirnet

Browse files
Files changed (1) hide show
  1. enhance_me/mirnet/mirnet.py +5 -2
enhance_me/mirnet/mirnet.py CHANGED
@@ -26,6 +26,8 @@ class MIRNet:
26
  experiment_name: str,
27
  image_size: int = 256,
28
  dataset_label: str = "lol",
 
 
29
  apply_random_horizontal_flip: bool = True,
30
  apply_random_vertical_flip: bool = True,
31
  apply_random_rotation: bool = True,
@@ -33,7 +35,8 @@ class MIRNet:
33
  ) -> None:
34
  self.experiment_name = experiment_name
35
  if dataset_label == "lol":
36
- download_lol_dataset()
 
37
  self.data_loader = LowLightDataset(
38
  image_size=image_size,
39
  apply_random_horizontal_flip=apply_random_horizontal_flip,
@@ -46,7 +49,7 @@ class MIRNet:
46
  else:
47
  self.using_wandb = False
48
 
49
- def build_datasets(
50
  self,
51
  low_light_images: List[str],
52
  enhanced_images: List[str],
 
26
  experiment_name: str,
27
  image_size: int = 256,
28
  dataset_label: str = "lol",
29
+ val_split: float = 0.2,
30
+ batch_size: int = 16,
31
  apply_random_horizontal_flip: bool = True,
32
  apply_random_vertical_flip: bool = True,
33
  apply_random_rotation: bool = True,
 
35
  ) -> None:
36
  self.experiment_name = experiment_name
37
  if dataset_label == "lol":
38
+ (low_images, enhanced_images), (self.test_low_images, self.test_enhanced_images) = download_lol_dataset()
39
+ self._build_datasets(low_images, enhanced_images, )
40
  self.data_loader = LowLightDataset(
41
  image_size=image_size,
42
  apply_random_horizontal_flip=apply_random_horizontal_flip,
 
49
  else:
50
  self.using_wandb = False
51
 
52
+ def _build_datasets(
53
  self,
54
  low_light_images: List[str],
55
  enhanced_images: List[str],