AveMujica committed
Commit 40aaca9 · 1 Parent(s): 3924e13
app.py CHANGED
@@ -5,6 +5,9 @@ import torch
 import gradio as gr
 import segmentation_models_pytorch as smp
 from PIL import Image
+import boto3
+import uuid
+import io
 from glob import glob
 from pipeline.ImgOutlier import detect_outliers
 from pipeline.normalization import align_images
@@ -12,13 +15,60 @@ from pipeline.normalization import align_images
 # Detect whether we are running in a Hugging Face environment
 HF_SPACE = os.environ.get('SPACE_ID') is not None
 
-# Try to import the upload module; it is only needed outside the HF environment
-if not HF_SPACE:
+# DigitalOcean Spaces upload function
+def upload_mask(image, prefix="mask"):
+    """
+    Upload a segmentation mask image to DigitalOcean Spaces.
+
+    Args:
+        image: PIL Image object
+        prefix: filename prefix
+
+    Returns:
+        URL of the uploaded file
+    """
     try:
-        from uploader.do_spaces import upload_mask
-    except ImportError:
-        def upload_mask(image, prefix=""):
-            return "Upload module not loaded"
+        # Read credentials from environment variables
+        do_key = os.environ.get('DO_SPACES_KEY')
+        do_secret = os.environ.get('DO_SPACES_SECRET')
+        do_region = os.environ.get('DO_SPACES_REGION')
+        do_bucket = os.environ.get('DO_SPACES_BUCKET')
+
+        # Verify that all credentials are present
+        if not all([do_key, do_secret, do_region, do_bucket]):
+            return "DigitalOcean credentials are not set"
+
+        # Create the S3 client
+        session = boto3.session.Session()
+        client = session.client('s3',
+                                region_name=do_region,
+                                endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
+                                aws_access_key_id=do_key,
+                                aws_secret_access_key=do_secret)
+
+        # Generate a unique filename
+        filename = f"{prefix}_{uuid.uuid4().hex}.png"
+
+        # Convert the image to a byte stream
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        img_byte_arr.seek(0)
+
+        # Upload to Spaces
+        client.upload_fileobj(
+            img_byte_arr,
+            do_bucket,
+            filename,
+            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
+        )
+
+        # Return the public URL
+        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
+        return url
+
+    except Exception as e:
+        print(f"Upload failed: {str(e)}")
+        return f"Upload error: {str(e)}"
 
 # Global Configuration
 MODEL_PATHS = {
@@ -54,6 +104,8 @@ def load_model(model_path, device="cuda"):
         # In the HF environment, default to CPU
         if HF_SPACE:
             device = "cpu"  # HF Spaces may not have a GPU
+        elif not torch.cuda.is_available():
+            device = "cpu"  # the local environment may not have a GPU either
 
         model = smp.create_model(
             "DeepLabV3Plus",
@@ -68,25 +120,33 @@ def load_model(model_path, device="cuda"):
         model.load_state_dict(state_dict)
         model.to(device)
         model.eval()
-        print(f"Model load success: {model_path}")
+        print(f"Model loaded successfully: {model_path}")
         return model
     except Exception as e:
-        print(f"Model load fail: {e}")
+        print(f"Model load failed: {e}")
         return None
 
 # Load reference vector
 def load_reference_vector(vector_path):
     try:
+        if not os.path.exists(vector_path):
+            print(f"Reference vector file does not exist: {vector_path}")
+            return []
         ref_vector = np.load(vector_path)
-        print(f"reference vector load success: {vector_path}")
+        print(f"Reference vector loaded successfully: {vector_path}")
         return ref_vector
     except Exception as e:
-        print(f"reference vector load {vector_path}: {e}")
+        print(f"Reference vector load failed {vector_path}: {e}")
        return []
 
 # Load reference image
 def load_reference_images(ref_dir):
     try:
+        if not os.path.exists(ref_dir):
+            print(f"Reference image directory does not exist: {ref_dir}")
+            os.makedirs(ref_dir, exist_ok=True)
+            return []
+
         image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
         image_files = []
         for ext in image_extensions:
@@ -97,10 +157,10 @@ def load_reference_images(ref_dir):
             img = cv2.imread(file)
             if img is not None:
                 reference_images.append(img)
-        print(f"from {ref_dir} load {len(reference_images)} images")
+        print(f"Loaded {len(reference_images)} images from {ref_dir}")
         return reference_images
     except Exception as e:
-        print(f"load image failed {ref_dir}: {e}")
+        print(f"Failed to load images from {ref_dir}: {e}")
         return []
 
 # Preprocess the image
@@ -173,7 +233,7 @@ def process_coastal_image(location, input_image):
     if model is None:
         return None, None, f"Error: failed to load the model", "Not checked", None
 
-    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location]) if os.path.exists(REFERENCE_VECTOR_PATHS[location]) else []
+    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
     ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])
 
     outlier_status = "Not checked"
@@ -183,21 +243,23 @@ def process_coastal_image(location, input_image):
     if len(ref_vector) > 0:
         filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
         is_outlier = len(filtered) == 0
-    else:
+    elif len(ref_images) > 0:
         filtered, _ = detect_outliers(ref_images, [image_bgr])
         is_outlier = len(filtered) == 0
+    else:
+        print("Warning: no reference images or reference vector available for outlier detection")
+        is_outlier = False
 
     outlier_status = "Outlier detection: <span style='color:red;font-weight:bold'>failed</span>" if is_outlier else "Outlier detection: <span style='color:green;font-weight:bold'>passed</span>"
     seg_map, overlay, analysis = perform_segmentation(model, image_bgr)
 
-    # Do not upload in the HF environment; only return local results
+    # Try to upload to DigitalOcean Spaces
     url = "Stored locally"
-    if not HF_SPACE:
-        try:
-            url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
-        except Exception as e:
-            print(f"Upload failed: {e}")
-            url = "Upload error"
+    try:
+        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
+    except Exception as e:
+        print(f"Upload failed: {e}")
+        url = f"Upload error: {str(e)}"
 
     if is_outlier:
         analysis = "<div style='color:red;font-weight:bold;margin-bottom:10px'>Warning: the image failed outlier detection; the result may be inaccurate!</div>" + analysis
@@ -218,19 +280,22 @@ def process_with_alignment(location, reference_image, input_image):
     ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
     tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
 
-    aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
-    aligned_tgt_bgr = aligned[1]
+    try:
+        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
+        aligned_tgt_bgr = aligned[1]
+    except Exception as e:
+        print(f"Spatial alignment failed: {e}")
+        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None
 
     seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)
 
-    # Do not upload in the HF environment; only return local results
+    # Try to upload to DigitalOcean Spaces
     url = "Stored locally"
-    if not HF_SPACE:
-        try:
-            url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
-        except Exception as e:
-            print(f"Upload failed: {e}")
-            url = "Upload error"
+    try:
+        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
+    except Exception as e:
+        print(f"Upload failed: {e}")
+        url = f"Upload error: {str(e)}"
 
     status = "Spatial alignment: <span style='color:green;font-weight:bold'>done</span>"
     ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
@@ -259,7 +324,7 @@ def create_interface():
                 url1 = gr.Text(label="Segmentation mask URL")
                 status1 = gr.HTML(label="Outlier detection status")
                 res1 = gr.HTML(label="Analysis results")
-            btn1.click(fn=process_coastal_image,inputs=[loc1, inp],outputs=[seg, ovl, res1, status1, url1])
+            btn1.click(fn=process_coastal_image, inputs=[loc1, inp], outputs=[seg, ovl, res1, status1, url1])
 
         with gr.TabItem("Spatially aligned segmentation"):
             with gr.Row():
@@ -275,18 +340,33 @@ def create_interface():
             with gr.Row():
                 seg2 = gr.Image(label="Segmentation image", type="numpy", width=disp_w, height=disp_h)
                 ovl2 = gr.Image(label="Overlay image", type="numpy", width=disp_w, height=disp_h)
-                url2 = gr.Text(label="Segmentation mask URL")
+            url2 = gr.Text(label="Segmentation mask URL")
             status2 = gr.HTML(label="Spatial alignment status")
             res2 = gr.HTML(label="Analysis results")
             btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img], outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])
     return demo
 
 if __name__ == "__main__":
+    # Create the required directories
     for path in ["models", "reference_images/MM", "reference_images/SJ"]:
         os.makedirs(path, exist_ok=True)
+
+    # Check that the model files exist
     for p in MODEL_PATHS.values():
         if not os.path.exists(p):
             print(f"Warning: model file {p} does not exist!")
+
+    # Check whether the DigitalOcean credentials are present
+    do_creds = [
+        os.environ.get('DO_SPACES_KEY'),
+        os.environ.get('DO_SPACES_SECRET'),
+        os.environ.get('DO_SPACES_REGION'),
+        os.environ.get('DO_SPACES_BUCKET')
+    ]
+    if not all(do_creds):
+        print("Warning: DigitalOcean Spaces credentials are incomplete; the upload feature may be unavailable")
+
+    # Create and launch the interface
     demo = create_interface()
     # Use an appropriate launch configuration in the HF environment
     if HF_SPACE:
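The hunk above is cut off at `if HF_SPACE:`, so the launch call itself is outside the diff context shown here. For orientation only, a minimal sketch of what such a branch typically looks like; the argument values below are assumptions, not this commit's code:

    if HF_SPACE:
        demo.launch()  # assumption: HF Spaces supplies host/port, so defaults suffice
    else:
        demo.launch(server_name="0.0.0.0", server_port=7860)  # assumption: typical self-hosted settings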
models/put model ADDED
File without changes
pipeline/HSV.py ADDED
@@ -0,0 +1,16 @@
+import cv2
+import numpy as np
+
+def preprocess_images(images, V_FIXED = 200):
+    fixed_images = []
+    for image in images:
+        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+        hsv_fixed = hsv.copy()
+        hsv_fixed[:, :, 2] = (hsv[:, :, 2] / hsv[:, :, 2].max()) * V_FIXED
+        hsv_fixed[:, :, 1] = hsv_fixed[:, :, 1] * (hsv_fixed[:, :, 2] / hsv[:, :, 2].max())
+        hsv_fixed[:, :, 1] = np.clip(hsv_fixed[:, :, 1], 0, 255)
+
+        fixed_image = cv2.cvtColor(hsv_fixed, cv2.COLOR_HSV2BGR)
+        fixed_images.append(fixed_image)
+    return fixed_images
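`preprocess_images` rescales each image's HSV value (brightness) channel so its maximum becomes `V_FIXED` and damps saturation by the same normalized-brightness factor, making photos taken under different lighting more comparable before segmentation. A minimal usage sketch, assuming a hypothetical input file:

    import cv2
    from pipeline.HSV import preprocess_images

    img = cv2.imread("example.jpg")                 # hypothetical input path (BGR)
    fixed = preprocess_images([img], V_FIXED=200)[0]
    cv2.imwrite("example_fixed.jpg", fixed)         # brightness-normalized BGR image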
pipeline/ImgOutlier.py ADDED
@@ -0,0 +1,200 @@
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from torch import nn
+from torchvision import transforms as tr
+from torchvision.models import vit_h_14
+import cv2
+
+class CosineSimilarity:
+    def __init__(self, vector='feature', threshold=0.8, mean_vec=[], device=None):
+        """
+        Initialize the CosineSimilarity class.
+
+        Args:
+            vector (str): Type of vector to use ('feature' or 'image')
+            threshold (float): Threshold for determining outliers
+            mean_vec (numpy vector): Preloaded reference vector for comparison
+            device (str): Device to use for computation (default: 'mps' if available, else 'cuda' if available, else 'cpu')
+        """
+        if device is None:
+            if torch.backends.mps.is_available():
+                self.device = 'mps'
+            elif torch.cuda.is_available():
+                self.device = 'cuda'
+            else:
+                self.device = 'cpu'
+        else:
+            self.device = device
+
+        self.vector = vector
+        self.threshold = threshold
+        self.model_instance = None
+        self.mean_vec = mean_vec
+
+    def model(self):
+        """Initialize and return the ViT model."""
+        if self.model_instance is None:
+            wt = torchvision.models.ViT_H_14_Weights.DEFAULT
+            self.model_instance = vit_h_14(weights=wt)
+            self.model_instance.heads = nn.Sequential(*list(self.model_instance.heads.children())[:-1])
+            self.model_instance = self.model_instance.to(self.device)
+        return self.model_instance
+
+    def process_image(self, cv2_img):
+        """
+        Process a cv2 image for the model.
+
+        Args:
+            cv2_img: OpenCV image (BGR format)
+
+        Returns:
+            Processed tensor
+        """
+        # Convert BGR to RGB
+        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
+        # Convert to PIL Image
+        pil_img = Image.fromarray(rgb_img)
+
+        # A set of transformations to prepare the image in tensor format
+        transformations = tr.Compose([
+            tr.ToTensor(),
+            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            tr.Resize((518, 518))
+        ])
+
+        # Prepare the image
+        img_tensor = transformations(pil_img).float()
+
+        if self.vector == 'image':
+            img_tensor = img_tensor.flatten()
+
+        img_tensor = img_tensor.unsqueeze_(0)
+
+        if self.vector == 'feature':
+            img_tensor = img_tensor.to(self.device)
+
+        return img_tensor
+
+    def get_embeddings(self, ref_images, test_images):
+        """
+        Get embeddings for reference and test images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            Reference embedding, list of test embeddings
+        """
+        model = self.model()
+
+        # Process test images
+        emb_test = []
+        for img in test_images:
+            processed_img = self.process_image(img)
+            if self.vector == 'feature':
+                emb = model(processed_img).detach().cpu()
+                emb_test.append(emb)
+            else:  # 'image'
+                emb_test.append(processed_img)
+
+        # If a reference vector is already loaded, the process of computing
+        # reference embeddings can be skipped for efficiency
+        if len(self.mean_vec) > 0:
+            emb_ref = torch.tensor(self.mean_vec)
+
+        # Process reference images if necessary
+        else:
+            if self.vector == 'feature':
+                # Standard method of getting the reference embedding vector
+                emb_ref_list = []
+                for img in ref_images:
+                    processed_img = self.process_image(img)
+                    emb = model(processed_img).detach().cpu()
+                    emb_ref_list.append(emb)
+
+                # Average the reference embeddings
+                emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0)
+
+            else:  # 'image'
+                emb_ref_list = []
+                for img in ref_images:
+                    processed_img = self.process_image(img)
+                    emb_ref_list.append(processed_img)
+
+                # Average the reference images
+                emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0)
+
+        return emb_ref, emb_test
+
+    def find_outliers(self, ref_images, test_images):
+        """
+        Find outliers in test images compared to reference images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            mask: Boolean array where True indicates an outlier
+            scores: Similarity scores for each test image
+            emb_ref: the reference embedding used for comparison
+        """
+        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)
+
+        scores = []
+        mask = []
+
+        for i in range(len(emb_test)):
+            score = torch.nn.functional.cosine_similarity(emb_ref, emb_test[i])
+            score_value = score.item()
+            scores.append(round(score_value, 4))
+            # True if it's an outlier (below threshold)
+            mask.append(score_value <= self.threshold)
+
+        return np.array(mask), scores, emb_ref
+
+    def filter_outliers(self, ref_images, test_images):
+        """
+        Filter out outliers from test images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            filtered_images: List of non-outlier test images
+            outlier_mask: Boolean array where True indicates an outlier
+            scores: Similarity scores for each test image
+            mean: the reference embedding used for comparison
+        """
+        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)
+
+        # Filter out outliers (keep only non-outliers)
+        filtered_images = [img for i, img in enumerate(test_images) if not outlier_mask[i]]
+
+        return filtered_images, outlier_mask, scores, mean
+
+def detect_outliers(ref_imgs, imgs, mean_vec=[]):
+    """
+    Detect outliers in a set of test images, optionally using a precomputed reference vector.
+
+    Args:
+        ref_imgs: List of cv2 reference images
+        imgs: List of cv2 test images
+        mean_vec: optional precomputed reference vector
+
+    Returns:
+        filtered_images: List of non-outlier test images
+        mean_vector: the reference vector used (in case a new reference vector should be saved)
+    """
+    similarity = CosineSimilarity(vector='feature', threshold=0.8, mean_vec=mean_vec)
+
+    # Get outlier mask, scores, and reference vector
+    outlier_mask, scores, mean_vector = similarity.find_outliers(ref_imgs, imgs)
+
+    # Filter out outliers
+    filtered_images = [img for i, img in enumerate(imgs) if not outlier_mask[i]]
+
+    return filtered_images, mean_vector
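`detect_outliers` embeds each image with a ViT-H/14 backbone, averages the reference embeddings (or reuses a precomputed `mean_vec`), and drops any test image whose cosine similarity to that reference falls at or below the 0.8 threshold. A minimal usage sketch with hypothetical paths; note the first call downloads the large ViT-H/14 weights:

    import cv2
    import numpy as np
    from pipeline.ImgOutlier import detect_outliers

    refs = [cv2.imread(p) for p in ["ref1.jpg", "ref2.jpg"]]  # hypothetical reference photos
    kept, mean_vector = detect_outliers(refs, [cv2.imread("new.jpg")])
    print("passed" if kept else "flagged as outlier")
    np.save("ref_vector.npy", mean_vector)  # reuse later to skip re-embedding the references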
pipeline/normalization.py ADDED
@@ -0,0 +1,77 @@
+from transformers import AutoImageProcessor, AutoModel
+import torch
+import numpy as np
+import cv2
+
+
+def align_images(images, segs):
+    """
+    Align images using SuperGlue for feature matching.
+
+    Args:
+        images: List of input images
+        segs: List of segmentation images
+
+    Returns:
+        Tuple of (aligned images, aligned segmentation images)
+    """
+    if not images or len(images) < 2:
+        return images, segs
+
+    reference = images[0]
+    reference_seg = segs[0]
+    aligned_images = [reference]
+    aligned_images_seg = [reference_seg]
+
+    # Load SuperGlue model and processor
+    processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+    model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
+
+    for i in range(1, len(images)):
+        current = images[i]
+        current_seg = segs[i]
+
+        # Process image pair
+        image_pair = [reference, current]
+        inputs = processor(image_pair, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # Get matches
+        image_sizes = [[(img.shape[0], img.shape[1]) for img in image_pair]]
+        matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+
+        # Extract matching keypoints
+        match_data = matches[0]
+        keypoints0 = match_data["keypoints0"].numpy()
+        keypoints1 = match_data["keypoints1"].numpy()
+
+        # Filter matches by confidence
+        valid_matches = match_data["matching_scores"] > 0.5
+        if sum(valid_matches) < 4:
+            print(f"Not enough confident matches for image {i}, keeping original")
+            aligned_images.append(current)
+            aligned_images_seg.append(current_seg)
+            continue
+
+        # Get matching points
+        src_pts = keypoints1[valid_matches].reshape(-1, 1, 2)
+        dst_pts = keypoints0[valid_matches].reshape(-1, 1, 2)
+
+        # Find homography
+        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
+
+        if H is not None:
+            # Apply homography
+            h, w = reference.shape[:2]
+            aligned = cv2.warpPerspective(current, H, (w, h))
+            aligned_images.append(aligned)
+            aligned_seg = cv2.warpPerspective(current_seg, H, (w, h))
+            aligned_images_seg.append(aligned_seg)
+        else:
+            print(f"Could not find homography for image {i}, keeping original")
+            aligned_images.append(current)
+            aligned_images_seg.append(current_seg)
+
+    return aligned_images, aligned_images_seg
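`align_images` matches keypoints between the first (reference) image and each later image with SuperGlue, estimates a RANSAC homography from matches above 0.5 confidence, and warps both the image and its segmentation map into the reference frame, falling back to the original when matches are insufficient. A minimal usage sketch with hypothetical paths; app.py passes zero arrays as the segmentation maps when it only needs the aligned photos:

    import cv2
    import numpy as np
    from pipeline.normalization import align_images

    ref = cv2.imread("reference.jpg")   # hypothetical paths (BGR)
    tgt = cv2.imread("target.jpg")
    imgs, segs = align_images([ref, tgt], [np.zeros_like(ref), np.zeros_like(tgt)])
    aligned_tgt = imgs[1]               # target warped into the reference frame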
reference_images/MM/New Text Document.txt ADDED
File without changes
reference_images/SJ/New Text Document.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch
+torchvision
+segmentation_models_pytorch
+gradio
+opencv-python
+numpy
+pillow
+transformers
+boto3
uploader/__init__.py ADDED
File without changes
uploader/do_spaces.py ADDED
@@ -0,0 +1,60 @@
+import os
+import boto3
+from botocore.client import Config
+import uuid
+from PIL import Image
+import io
+
+def upload_mask(image, prefix="mask"):
+    """
+    Upload a segmentation mask image to DigitalOcean Spaces.
+
+    Args:
+        image: PIL Image object
+        prefix: filename prefix
+
+    Returns:
+        URL of the uploaded file
+    """
+    try:
+        # Read credentials from environment variables
+        do_key = os.environ.get('DO_SPACES_KEY')
+        do_secret = os.environ.get('DO_SPACES_SECRET')
+        do_region = os.environ.get('DO_SPACES_REGION')
+        do_bucket = os.environ.get('DO_SPACES_BUCKET')
+
+        # Verify that all credentials are present
+        if not all([do_key, do_secret, do_region, do_bucket]):
+            raise ValueError("Missing DigitalOcean Spaces credentials")
+
+        # Create the S3 client
+        session = boto3.session.Session()
+        client = session.client('s3',
+                                region_name=do_region,
+                                endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
+                                aws_access_key_id=do_key,
+                                aws_secret_access_key=do_secret)
+
+        # Generate a unique filename
+        filename = f"{prefix}_{uuid.uuid4().hex}.png"
+
+        # Convert the image to a byte stream
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        img_byte_arr.seek(0)
+
+        # Upload to Spaces
+        client.upload_fileobj(
+            img_byte_arr,
+            do_bucket,
+            filename,
+            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
+        )
+
+        # Return the public URL
+        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
+        return url
+
+    except Exception as e:
+        print(f"Upload failed: {str(e)}")
+        return f"Upload error: {str(e)}"