csgaobb committed
Commit 020dd6e
1 Parent(s): aa603f8

init MetaUAS

Files changed (4)
  1. app.py +160 -0
  2. demo_metauas.py +90 -0
  3. metauas.py +293 -0
  4. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,160 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ '''
+ @File : app.py
+ @Time : 2025/03/26 23:48:24
+ @Author : Bin-Bin Gao
+ @Email : [email protected]
+ @Homepage: https://csgaobb.github.io/
+ @Version : 1.0
+ @Desc : MetaUAS Demo with Gradio
+ '''
+
+
+ import os
+ import cv2
+ import torch
+ import json
+ import shutil
+ import kornia as K
+ import numpy as np
+ import gradio as gr
+ from easydict import EasyDict
+ from argparse import ArgumentParser
+ from torchvision.transforms.functional import pil_to_tensor
+
+ from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, safely_load_state_dict
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # configurations
+ random_seed = 1
+ encoder_name = 'efficientnet-b4'
+ decoder_name = 'unet'
+ encoder_depth = 5
+ decoder_depth = 5
+ num_alignment_layers = 3
+ alignment_type = 'sa'
+ fusion_policy = 'cat'
+
+
+ # build model
+ set_random_seed(random_seed)
+ metauas_model = MetaUAS(encoder_name,
+                         decoder_name,
+                         encoder_depth,
+                         decoder_depth,
+                         num_alignment_layers,
+                         alignment_type,
+                         fusion_policy
+                         )
+
+ def process_image(prompt_img, query_img, options):
+     # Load the model checkpoint matching the selected option
+     if 'model-512' in options:
+         ckt_path = "weights/metauas-512.ckpt"
+         model = safely_load_state_dict(metauas_model, ckt_path)
+         img_size = 512
+     else:
+         ckt_path = 'weights/metauas-256.ckpt'
+         model = safely_load_state_dict(metauas_model, ckt_path)
+         img_size = 256
+
+     model.to(device)
+     model.eval()
+
+     # Ensure images are in RGB mode
+     prompt_img = prompt_img.convert('RGB')
+     query_img = query_img.convert('RGB')
+
+     query_img = pil_to_tensor(query_img).float() / 255.0
+     prompt_img = pil_to_tensor(prompt_img).float() / 255.0
+
+     if query_img.shape[1] != img_size:
+         resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
+         query_img = resize_trans(query_img)[0]
+         prompt_img = resize_trans(prompt_img)[0]
+
+
+     test_data = {
+         "query_image": query_img.to(device),
+         "prompt_image": prompt_img.to(device),
+     }
+
+
+     # Forward
+     with torch.no_grad():
+         predicted_masks = model(test_data)
+         anomaly_score = predicted_masks[:].max()
+
+     # Process anomaly map
+     query_img = test_data["query_image"][0] * 255
+     query_img = query_img.permute(1, 2, 0)
+
+     anomaly_map = predicted_masks.squeeze().detach()[:, :, None].cpu().numpy().repeat(3, 2)
+
+     anomaly_map_vis = apply_ad_scoremap(query_img.cpu(), normalize(anomaly_map))
+
+
+     anomaly_map = (anomaly_map * 255).astype(np.uint8)
+     anomaly_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
+     anomaly_map = cv2.cvtColor(anomaly_map, cv2.COLOR_BGR2RGB)
+
+     return anomaly_map_vis, anomaly_map, f'{anomaly_score:.3f}'
+
+ # Define examples
+ examples = [
+     ["images/134.png", "images/000.png", "model-256"],
+     ["images/036.png", "images/024.png", "model-256"],
+     ["images/178.png", "images/003.png", "model-256"],
+ ]
+
+ # Gradio interface layout
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center" style='margin-top: 30px;'>MetaUAS: Universal Anomaly Segmentation</h1>""")
+     gr.HTML("""<h1 align="center" style="font-size: 15px; margin-top: 40px;">given just ONE normal image prompt</h1>""")
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 prompt_image = gr.Image(type="pil", label="Prompt Image")
+                 query_image = gr.Image(type="pil", label="Query Image")
+             model_selector = gr.Radio(["model-256", "model-512"], label="Pre-models")
+
+         with gr.Column():
+             with gr.Row():
+                 anomaly_map_vis = gr.Image(type="pil", label="Anomaly Results")
+                 anomaly_map = gr.Image(type="pil", label="Anomaly Maps")
+             anomaly_score = gr.Textbox(label="Anomaly Score")
+
+     with gr.Row():
+         submit_button = gr.Button("Submit", elem_id="submit-button")
+         clear_button = gr.Button("Clear")
+
+     # Set up the event handlers
+     submit_button.click(process_image, inputs=[prompt_image, query_image, model_selector], outputs=[anomaly_map_vis, anomaly_map, anomaly_score])
+     clear_button.click(lambda: (None, None, None), outputs=[anomaly_map_vis, anomaly_map, anomaly_score])
+
+     # Add examples directly to the Blocks interface
+     gr.Examples(examples, inputs=[prompt_image, query_image, model_selector])
+
+ # Add custom CSS to control the output image size and button styles
+ demo.css = """
+ #submit-button {
+     color: red !important;                 /* Font color */
+     background-color: orange !important;   /* Background color */
+     border: none !important;               /* Remove border */
+     padding: 10px 20px !important;         /* Add padding */
+     border-radius: 5px !important;         /* Rounded corners */
+     font-size: 16px !important;            /* Font size */
+     cursor: pointer !important;            /* Pointer cursor on hover */
+ }
+
+ #submit-button:hover {
+     background-color: darkorange !important;  /* Darker orange on hover */
+ }
+ """
+
+ # Launch the demo
+ demo.launch()
+
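Usage sketch for app.py (assumptions: the pretrained checkpoints are placed under weights/ as weights/metauas-256.ckpt and weights/metauas-512.ckpt, and the example images under images/ as referenced above; neither directory is part of this commit):

    pip install -r requirements.txt
    python app.py

demo.launch() then serves the Gradio interface locally, by default on port 7860.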
demo_metauas.py ADDED
@@ -0,0 +1,90 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ '''
+ @File : demo_metauas.py
+ @Time : 2025/03/26 23:49:14
+ @Author : Bin-Bin Gao
+ @Email : [email protected]
+ @Homepage: https://csgaobb.github.io/
+ @Version : 1.0
+ @Desc : MetaUAS Demo
+ '''
+
+
+ import os
+ import cv2
+ import torch
+ import json
+ import shutil
+ import kornia as K
+ import numpy as np
+
+ from easydict import EasyDict
+ from argparse import ArgumentParser
+ from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, read_image_as_tensor, safely_load_state_dict
+
+ if __name__ == "__main__":
+     random_seed = 1
+
+     set_random_seed(random_seed)
+
+     ckt_path = 'weights/metauas-256.ckpt'
+     img_size = 256
+     #ckt_path = "weights/metauas-512.ckpt"
+     #img_size = 512
+
+     # load model
+     encoder = 'efficientnet-b4'
+     decoder = 'unet'
+     encoder_depth = 5
+     decoder_depth = 5
+     num_crossfa_layers = 3
+     alignment_type = 'sa'
+     fusion_policy = 'cat'
+
+     model = MetaUAS(encoder,
+                     decoder,
+                     encoder_depth,
+                     decoder_depth,
+                     num_crossfa_layers,
+                     alignment_type,
+                     fusion_policy
+                     )
+
+
+     model = safely_load_state_dict(model, ckt_path)
+     model.cuda()
+     model.eval()
+
+
+     # load test images
+     path_root = "./images/"
+     path_to_prompt = path_root + "036.png"
+     path_to_query = path_root + "024.png"
+
+     query = read_image_as_tensor(path_to_query)
+     prompt = read_image_as_tensor(path_to_prompt)
+
+     if query.shape[1] != img_size:
+         resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
+         query = resize_trans(query)[0]
+         prompt = resize_trans(prompt)[0]
+
+
+     test_data = {
+         "query_image": query.cuda(),
+         "prompt_image": prompt.cuda(),
+     }
+
+     # forward
+     predicted_masks = model(test_data)
+
+     # visualization
+     query_img = test_data["query_image"][0] * 255
+     query_img = query_img.permute(1, 2, 0)
+
+     pred = (1 - predicted_masks.squeeze().detach())[:, :, None].cpu().numpy().repeat(3, 2)
+     # normalize just for analysis
+     scoremap_self = apply_ad_scoremap(query_img.cpu(), normalize(pred))
+     cv2.imwrite('./anomaly_map.jpg', scoremap_self)
+
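The script above is a minimal offline run of the same pipeline; it assumes a CUDA device (it calls model.cuda()), the 256-pixel checkpoint at weights/metauas-256.ckpt, and the prompt/query pair images/036.png and images/024.png. Under those assumptions,

    python demo_metauas.py

writes the blended anomaly heat map to ./anomaly_map.jpg.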
metauas.py ADDED
@@ -0,0 +1,293 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ '''
+ @File : metauas.py
+ @Time : 2025/03/26 23:46:12
+ @Author : Bin-Bin Gao
+ @Email : [email protected]
+ @Homepage: https://csgaobb.github.io/
+ @Version : 1.0
+ @Desc : some classes and functions for MetaUAS
+ '''
+
+
+ import os
+ import random
+ import kornia as K
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import tqdm
+ import time
+ import cv2
+
+ from PIL import Image
+ from einops import rearrange
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms.functional import pil_to_tensor
+ from segmentation_models_pytorch.unet.model import UnetDecoder
+ from segmentation_models_pytorch.fpn.decoder import FPNDecoder
+ from segmentation_models_pytorch.encoders import get_encoder, get_preprocessing_params
+
+ def set_random_seed(seed=233, reproduce=False):
+     np.random.seed(seed)
+     torch.manual_seed(seed ** 2)
+     torch.cuda.manual_seed(seed ** 3)
+     random.seed(seed ** 4)
+
+     if reproduce:
+         torch.backends.cudnn.benchmark = False
+         torch.backends.cudnn.deterministic = True
+     else:
+         torch.backends.cudnn.benchmark = True
+
+ def normalize(pred, max_value=None, min_value=None):
+     if max_value is None or min_value is None:
+         return (pred - pred.min()) / (pred.max() - pred.min())
+     else:
+         return (pred - min_value) / (max_value - min_value)
+
+
+ def apply_ad_scoremap(image, scoremap, alpha=0.5):
+     np_image = np.asarray(image, dtype=np.float32)
+     scoremap = (scoremap * 255).astype(np.uint8)
+     scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
+     scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
+     return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
+
+
+ def read_image_as_tensor(path_to_image):
+     pil_image = Image.open(path_to_image).convert("RGB")
+     image_as_tensor = pil_to_tensor(pil_image).float() / 255.0
+     return image_as_tensor
+
+ def safely_load_state_dict(model, checkpoint):
+     model.load_state_dict(torch.load(checkpoint), strict=True)
+     return model
+
+
+ class AlignmentModule(nn.Module):
+     def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa", fusion_policy='cat'):
+         super().__init__()
+         self.fusion_policy = fusion_policy
+         self.alignment_layer = AlignmentLayer(input_channels, hidden_channels, alignment_type=alignment_type)
+
+     def forward(self, query_features, prompt_features):
+         if isinstance(prompt_features, list):
+             aligned_prompt = []
+             for i in range(len(prompt_features)):
+                 aligned_prompt.append(self.alignment_layer(query_features, prompt_features[i]))
+             aligned_prompt = torch.mean(torch.stack(aligned_prompt), 0)
+
+         else:
+             aligned_prompt = self.alignment_layer(query_features, prompt_features)
+
+         if self.fusion_policy == 'cat':
+             query_features = rearrange(
+                 [query_features, aligned_prompt], "two b c h w -> b (two c) h w"
+             )
+         elif self.fusion_policy == 'add':
+             query_features = query_features + aligned_prompt
+
+         elif self.fusion_policy == 'absdiff':
+             query_features = (query_features - aligned_prompt).abs()
+
+         return query_features
+
+ class AlignmentLayer(nn.Module):
+     def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa"):
+         super().__init__()
+         self.alignment_type = alignment_type
+         if alignment_type != "na":
+             self.dimensionality_reduction = nn.Conv2d(
+                 input_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=True
+             )
+
+     def forward(self, query_features, prompt_features):
+         # no-alignment
+         if self.alignment_type == 'na':
+             return prompt_features
+         else:
+             Q = self.dimensionality_reduction(query_features)
+             K = self.dimensionality_reduction(prompt_features)
+             V = rearrange(prompt_features, "b c h w -> b c (h w)")
+
+             soft_attention_map = torch.einsum("bcij,bckl->bijkl", Q, K)
+             soft_attention_map = rearrange(soft_attention_map, "b h1 w1 h2 w2 -> b h1 w1 (h2 w2)")
+             soft_attention_map = nn.Softmax(dim=3)(soft_attention_map)
+
+             # soft-alignment
+             if self.alignment_type == 'sa':
+                 aligned_features = torch.einsum("bijp,bcp->bcij", soft_attention_map, V)
+             # hard-alignment
+             if self.alignment_type == 'ha':
+                 max_v, max_index = soft_attention_map.max(dim=-1, keepdim=True)
+                 hard_attention_map = (soft_attention_map == max_v).float()
+                 aligned_features = torch.einsum("bijp,bcp->bcij", hard_attention_map, V)
+
+             return aligned_features
+
+
+ class MetaUAS(pl.LightningModule):
+     def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
+         super().__init__()
+
+         self.encoder_name = encoder_name
+         self.decoder_name = decoder_name
+         self.encoder_depth = encoder_depth
+         self.decoder_depth = decoder_depth
+
+         self.num_alignment_layers = num_alignment_layers
+         self.alignment_type = alignment_type
+         self.fusion_policy = fusion_policy
+
+
+         align_input_channels = [448, 160, 56]
+         align_hidden_channels = [224, 80, 28]
+         encoder_channels = [3, 48, 32, 56, 160, 448]
+         decoder_channels = [256, 128, 64, 64, 48]
+
+         self.encoder = get_encoder(
+             self.encoder_name,
+             in_channels=3,
+             depth=self.encoder_depth,
+             weights="imagenet",)
+
+         preparams = get_preprocessing_params(
+             self.encoder_name,
+             pretrained="imagenet"
+         )
+
+         self.preprocess = transforms.Normalize(preparams['mean'], preparams['std'])
+
+         self.encoder.eval()
+         for param in self.encoder.parameters():
+             param.requires_grad = False
+
+         if self.decoder_name == "unet":
+             encoder_out_channels = encoder_channels[self.encoder_depth - self.decoder_depth:]
+             if self.fusion_policy == 'cat':
+                 num_alignment_layers = self.num_alignment_layers
+             elif self.fusion_policy == 'add' or self.fusion_policy == 'absdiff':
+                 num_alignment_layers = 0
+
+             self.decoder = UnetDecoder(
+                 encoder_channels=encoder_out_channels,
+                 decoder_channels=decoder_channels,
+                 n_blocks=self.decoder_depth,
+                 attention_type="scse",
+                 num_coam_layers=num_alignment_layers,
+             )
+
+         elif self.decoder_name == "fpn":
+             encoder_out_channels = encoder_channels
+             if self.fusion_policy == 'cat':
+                 for i in range(self.num_alignment_layers):
+                     encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
+
+             self.decoder = FPNDecoder(
+                 encoder_channels=encoder_out_channels,
+                 encoder_depth=self.encoder_depth,
+                 pyramid_channels=256,
+                 segmentation_channels=decoder_channels[-1],
+                 dropout=0.2,
+                 merge_policy="add",
+             )
+
+         elif self.decoder_name == "fpnadd":
+             segmentation_channels = 256  # 128
+             encoder_out_channels = encoder_channels
+             if self.fusion_policy == 'cat':
+                 for i in range(self.num_alignment_layers):
+                     encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
+
+             self.decoder = FPNDecoder(
+                 encoder_channels=encoder_out_channels,
+                 encoder_depth=self.encoder_depth,
+                 pyramid_channels=256,
+                 segmentation_channels=segmentation_channels,
+                 dropout=0.2,
+                 merge_policy="add",
+             )
+         elif self.decoder_name == "fpncat":
+             encoder_out_channels = encoder_channels
+             segmentation_channels = 256  # 128
+             if self.fusion_policy == 'cat':
+                 for i in range(self.num_alignment_layers):
+                     encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
+
+             self.decoder = FPNDecoder(
+                 encoder_channels=encoder_out_channels,
+                 encoder_depth=self.encoder_depth,
+                 pyramid_channels=256,
+                 segmentation_channels=segmentation_channels,
+                 dropout=0.2,
+                 merge_policy="cat",
+             )
+
+
+         if self.alignment_type == "sa" or self.alignment_type == "na" or self.alignment_type == "ha":
+             self.alignment = nn.ModuleList(
+                 [
+                     AlignmentModule(
+                         input_channels=align_input_channels[i],
+                         hidden_channels=align_hidden_channels[i],
+                         alignment_type=self.alignment_type,
+                         fusion_policy=self.fusion_policy,
+                     )
+                     for i in range(self.num_alignment_layers)
+                 ]
+             )
+
+         if self.decoder_name == "fpncat":
+             self.mask_head = nn.Conv2d(
+                 segmentation_channels * 4,
+                 1,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+         elif self.decoder_name == "fpnadd":
+             self.mask_head = nn.Conv2d(
+                 segmentation_channels,
+                 1,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+         else:
+             self.mask_head = nn.Conv2d(
+                 decoder_channels[-1],
+                 1,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+
+     def forward(self, batch):
+         query_input = self.preprocess(batch["query_image"])
+         prompt_input = self.preprocess(batch["prompt_image"])
+
+         with torch.no_grad():
+             query_encoded_features = self.encoder(query_input)
+             prompt_encoded_features = self.encoder(prompt_input)
+
+         for i in range(len(self.alignment)):
+             query_encoded_features[-(i + 1)] = self.alignment[i](query_encoded_features[-(i + 1)], prompt_encoded_features[-(i + 1)])
+
+         query_decoded_features = self.decoder(*query_encoded_features[self.encoder_depth - self.decoder_depth:])
+
+         if self.decoder_name == "fpn" or self.decoder_name == "fpncat" or self.decoder_name == "fpnadd":
+             output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=4, mode="bilinear", align_corners=False)
+
+         elif self.decoder_name == "unet":
+             if self.decoder_depth == 4:
+                 output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=2, mode="bilinear", align_corners=False)
+             if self.decoder_depth == 5:
+                 if not self.training:
+                     output = self.mask_head(query_decoded_features)
+
+         return output.sigmoid()
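As a reading aid for AlignmentLayer, here is a minimal shape check of the soft-alignment ('sa') path with random tensors and hypothetical sizes (224 reduced channels, 448 prompt channels, 8x8 feature maps); it only mirrors the einsum/softmax flow above, not the trained model:

    import torch
    from einops import rearrange

    b, c_hidden, c_in, h, w = 2, 224, 448, 8, 8   # hypothetical sizes
    Q = torch.randn(b, c_hidden, h, w)            # reduced query features
    K = torch.randn(b, c_hidden, h, w)            # reduced prompt features
    V = rearrange(torch.randn(b, c_in, h, w), "b c h w -> b c (h w)")  # prompt features, flattened

    attn = torch.einsum("bcij,bckl->bijkl", Q, K)               # (b, h, w, h, w) similarities
    attn = rearrange(attn, "b h1 w1 h2 w2 -> b h1 w1 (h2 w2)")  # flatten prompt positions
    attn = attn.softmax(dim=3)                                  # weights over prompt positions
    aligned = torch.einsum("bijp,bcp->bcij", attn, V)           # (b, c_in, h, w) aligned prompt

    assert aligned.shape == (b, c_in, h, w)

Each query location thus receives a convex combination of prompt features, which AlignmentModule then fuses with the query features via 'cat', 'add', or 'absdiff'.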
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ easydict==1.11
+ einops==0.8.1
+ gradio==4.0.0
+ kornia==0.6.3
+ matplotlib==3.5.0
+ numpy==1.24.4
+ opencv_python==4.6.0.66
+ opencv_python_headless==4.7.0.72
+ Pillow==8.4.0
+ pytorch_lightning==1.9.0
+ segmentation_models_pytorch==0.2.1
+ torch==1.12.1+cu113
+ torchvision==0.13.1+cu113
+ tqdm==4.62.3
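Note that torch==1.12.1+cu113 and torchvision==0.13.1+cu113 are CUDA 11.3 builds that are not published on PyPI; assuming pip is used, they resolve from the PyTorch wheel index, e.g.:

    pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113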