zhjohnchan committed
Commit a2a587d · verified · 1 Parent(s): 7d44288

Update modeling_visual.py

Files changed (1):
  1. modeling_visual.py +2 -78
modeling_visual.py CHANGED
@@ -22,68 +22,6 @@ from transformers import AutoModel, AutoProcessor
 from transformers.activations import ACT2FN
 
 
-class TransformCXR(object):
-    def __init__(
-            self,
-            image_size=448,
-            mean=(0.48145466, 0.4578275, 0.40821073),
-            std=(0.26862954, 0.26130258, 0.27577711),
-            allow_shift=True,
-            training=True,
-            normalize=True
-    ):
-
-        resize_size = image_size
-        p_train = 0.5
-        shift_limit = (-0.0, 0.0)
-        scale_limit = (-0.1, -0.02)
-        rotate_limit = 5
-        scale = (0.00, 0.01)
-        brightness_limit = (-0.15, 0.15)
-        contrast_limit = (-0.05, 0.05)
-        pad_mode = cv2.BORDER_CONSTANT
-        pad_val = (0, 0, 0)
-
-        if training:
-            if allow_shift:
-                transform_list = [
-                    A.ShiftScaleRotate(
-                        shift_limit=shift_limit, scale_limit=scale_limit,
-                        rotate_limit=rotate_limit, border_mode=pad_mode, value=pad_val,
-                        p=p_train
-                    ),
-                    A.Perspective(
-                        scale=scale, pad_mode=pad_mode, pad_val=pad_val, p=p_train
-                    ),
-                    A.Resize(height=resize_size, width=resize_size, interpolation=cv2.INTER_CUBIC),
-                    A.RandomCrop(height=image_size, width=image_size),
-                    A.RandomBrightnessContrast(
-                        brightness_limit=brightness_limit, contrast_limit=contrast_limit,
-                        p=p_train
-                    )
-                ]
-            else:
-                transform_list = [
-                    A.Resize(height=image_size, width=image_size, interpolation=cv2.INTER_CUBIC),
-                    A.RandomBrightnessContrast(
-                        brightness_limit=brightness_limit, contrast_limit=contrast_limit,
-                        p=p_train
-                    )
-                ]
-        else:
-            transform_list = [
-                A.Resize(height=image_size, width=image_size, interpolation=cv2.INTER_CUBIC)
-            ]
-
-        if normalize:
-            transform_list += [A.Normalize(mean=mean, std=std), ToTensorV2(transpose_mask=True)]
-
-        self.transforms = A.Compose(transform_list)
-
-    def __call__(self, image):
-        image = np.array(image)
-        return self.transforms(image=image)['image']
-
 
 def get_abs_pos(abs_pos, tgt_size):
     # abs_pos: L, C
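
Note on the removed class: TransformCXR depended on albumentations, OpenCV, and NumPy; its import block sits outside this hunk, so the aliases below are assumptions. A minimal sketch of how the deleted class was driven:

import cv2                                     # cv2.BORDER_CONSTANT, cv2.INTER_CUBIC
import numpy as np                             # PIL -> ndarray conversion in __call__
import albumentations as A                     # A.Compose, A.Resize, A.ShiftScaleRotate, ...
from albumentations.pytorch import ToTensorV2  # appended when normalize=True
from PIL import Image

# Hypothetical usage: augment one chest X-ray in training mode.
transform = TransformCXR(image_size=448, training=True, allow_shift=True)
image = Image.open("chest_xray.png")           # hypothetical example file
tensor = transform(image)                      # normalized CHW torch tensor
print(tuple(tensor.shape))                     # (3, 448, 448)

In training mode with allow_shift=True the pipeline chained ShiftScaleRotate, Perspective, Resize, RandomCrop, and RandomBrightnessContrast (each stochastic at p=0.5); at eval time it only resized, with Normalize and ToTensorV2 appended whenever normalize=True.
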
@@ -247,11 +185,7 @@ class CLIPModel(nn.Module):
         # Transforms
         self.mean = self.processor.image_mean
         self.std = self.processor.image_std
-        self.image_transform_train = TransformCXR(image_size=image_size, mean=self.mean, std=self.std, training=True)
-        self.image_transform_train_no_shift = TransformCXR(
-            image_size=image_size, mean=self.mean, std=self.std, allow_shift=False, training=True
-        )
-        self.image_transform_val = TransformCXR(image_size=image_size, mean=self.mean, std=self.std, training=False)
+
         self.image_transform = transforms.Compose([
             transforms.Resize(
                 (image_size, image_size),
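
Only the Resize head of the surviving torchvision pipeline is visible in this hunk. A plausible standalone completion, in which the interpolation mode and the ToTensor/Normalize tail are assumptions (they fall outside the shown context), and the CLIP statistics from the deleted defaults stand in for self.processor.image_mean/image_std:

from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_size = 448                               # default carried by the deleted class
mean = (0.48145466, 0.4578275, 0.40821073)     # CLIP stats, per the deleted defaults
std = (0.26862954, 0.26130258, 0.27577711)

# Resize head as shown in the diff; interpolation and the tail are assumed.
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),                     # PIL -> float tensor in [0, 1]
    transforms.Normalize(mean=mean, std=std),  # channel-wise standardization
])
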
@@ -298,17 +232,7 @@ class CLIPModel(nn.Module):
 
         image = image.convert("RGB")
 
-        no_shift = any([keyword in image_path for keyword in ["vindr", "candid", "siim", "object-cxr", "ms-cxr"]])
-        try:
-            if training or self.training:
-                if no_shift:
-                    image_tensor = self.image_transform_train_no_shift(image)
-                else:
-                    image_tensor = self.image_transform_train(image)
-            else:
-                image_tensor = self.image_transform_val(image)
-        except:
-            image_tensor = self.image_transform(image)
+        image_tensor = self.image_transform(image)
         return image_tensor
 
     def encode(self, image_paths: List[str], training):
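
With this hunk, training and evaluation share a single deterministic preprocessing path: the dataset-keyword routing ("vindr", "candid", "siim", "object-cxr", "ms-cxr"), the train/val branching, and the bare except fallback are all gone. A standalone sketch of the resulting flow; preprocess is a hypothetical name, since the hunk header does not show the enclosing method:

from PIL import Image

def preprocess(image_path, image_transform):
    # Hypothetical reconstruction of the simplified method body: every image,
    # train or eval, now goes through the one torchvision transform.
    image = Image.open(image_path)
    image = image.convert("RGB")
    return image_transform(image)

One practical consequence: the old bare except masked any augmentation failure by silently falling back to the deterministic transform; after this commit that fallback is the only path.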