init

- app.py +112 -32
- models/put model +0 -0
- pipeline/HSV.py +16 -0
- pipeline/ImgOutlier.py +200 -0
- pipeline/normalization.py +77 -0
- reference_images/MM/New Text Document.txt +0 -0
- reference_images/SJ/New Text Document.txt +0 -0
- requirements.txt +9 -0
- uploader/__init__.py +0 -0
- uploader/do_spaces.py +60 -0
app.py
CHANGED
@@ -5,6 +5,9 @@ import torch
 import gradio as gr
 import segmentation_models_pytorch as smp
 from PIL import Image
+import boto3
+import uuid
+import io
 from glob import glob
 from pipeline.ImgOutlier import detect_outliers
 from pipeline.normalization import align_images
@@ -12,13 +15,60 @@ from pipeline.normalization import align_images
 # Detect whether we are running in a Hugging Face environment
 HF_SPACE = os.environ.get('SPACE_ID') is not None
 
-#
-
+# DigitalOcean Spaces upload function
+def upload_mask(image, prefix="mask"):
+    """
+    Upload a segmentation mask image to DigitalOcean Spaces.
+
+    Args:
+        image: PIL Image object
+        prefix: filename prefix
+
+    Returns:
+        URL of the uploaded file
+    """
     try:
-
-
-
-
+        # Get credentials from environment variables
+        do_key = os.environ.get('DO_SPACES_KEY')
+        do_secret = os.environ.get('DO_SPACES_SECRET')
+        do_region = os.environ.get('DO_SPACES_REGION')
+        do_bucket = os.environ.get('DO_SPACES_BUCKET')
+
+        # Verify that the credentials exist
+        if not all([do_key, do_secret, do_region, do_bucket]):
+            return "DigitalOcean credentials not set"
+
+        # Create the S3 client
+        session = boto3.session.Session()
+        client = session.client('s3',
+                                region_name=do_region,
+                                endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
+                                aws_access_key_id=do_key,
+                                aws_secret_access_key=do_secret)
+
+        # Generate a unique filename
+        filename = f"{prefix}_{uuid.uuid4().hex}.png"
+
+        # Convert the image to a byte stream
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        img_byte_arr.seek(0)
+
+        # Upload to Spaces
+        client.upload_fileobj(
+            img_byte_arr,
+            do_bucket,
+            filename,
+            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
+        )
+
+        # Return the public URL
+        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
+        return url
+
+    except Exception as e:
+        print(f"Upload failed: {str(e)}")
+        return f"Upload error: {str(e)}"
 
 # Global Configuration
 MODEL_PATHS = {
@@ -54,6 +104,8 @@ def load_model(model_path, device="cuda"):
     # Default to CPU when in the HF environment
     if HF_SPACE:
         device = "cpu"  # the HF Space may not have a GPU
+    elif not torch.cuda.is_available():
+        device = "cpu"  # the local environment may not have a GPU either
 
     model = smp.create_model(
         "DeepLabV3Plus",
@@ -68,25 +120,33 @@ def load_model(model_path, device="cuda"):
         model.load_state_dict(state_dict)
         model.to(device)
         model.eval()
-        print(f"
+        print(f"Model loaded successfully: {model_path}")
         return model
     except Exception as e:
-        print(f"
+        print(f"Failed to load model: {e}")
         return None
 
 # Load reference vector
 def load_reference_vector(vector_path):
     try:
+        if not os.path.exists(vector_path):
+            print(f"Reference vector file does not exist: {vector_path}")
+            return []
         ref_vector = np.load(vector_path)
-        print(f"
+        print(f"Reference vector loaded successfully: {vector_path}")
        return ref_vector
     except Exception as e:
-        print(f"
+        print(f"Failed to load reference vector {vector_path}: {e}")
         return []
 
 # Load reference images
 def load_reference_images(ref_dir):
     try:
+        if not os.path.exists(ref_dir):
+            print(f"Reference image directory does not exist: {ref_dir}")
+            os.makedirs(ref_dir, exist_ok=True)
+            return []
+
         image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
         image_files = []
         for ext in image_extensions:
@@ -97,10 +157,10 @@ def load_reference_images(ref_dir):
             img = cv2.imread(file)
             if img is not None:
                 reference_images.append(img)
-        print(f"
+        print(f"Loaded {len(reference_images)} images from {ref_dir}")
         return reference_images
     except Exception as e:
-        print(f"
+        print(f"Failed to load images from {ref_dir}: {e}")
         return []
 
 # Preprocess the image
@@ -173,7 +233,7 @@ def process_coastal_image(location, input_image):
     if model is None:
         return None, None, f"Error: could not load the model", "Not checked", None
 
-    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
+    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
     ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])
 
     outlier_status = "Not checked"
@@ -183,21 +243,23 @@ def process_coastal_image(location, input_image):
     if len(ref_vector) > 0:
         filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
         is_outlier = len(filtered) == 0
-
+    elif len(ref_images) > 0:
         filtered, _ = detect_outliers(ref_images, [image_bgr])
         is_outlier = len(filtered) == 0
+    else:
+        print("Warning: no reference images or reference vector available for outlier detection")
+        is_outlier = False
 
     outlier_status = "Outlier detection: <span style='color:red;font-weight:bold'>FAILED</span>" if is_outlier else "Outlier detection: <span style='color:green;font-weight:bold'>PASSED</span>"
     seg_map, overlay, analysis = perform_segmentation(model, image_bgr)
 
-    #
+    # Try uploading to DigitalOcean Spaces
     url = "Local storage"
-
-
-
-
-
-        url = "Upload error"
+    try:
+        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
+    except Exception as e:
+        print(f"Upload failed: {e}")
+        url = f"Upload error: {str(e)}"
 
     if is_outlier:
         analysis = "<div style='color:red;font-weight:bold;margin-bottom:10px'>Warning: the image failed outlier detection; results may be inaccurate!</div>" + analysis
@@ -218,19 +280,22 @@ def process_with_alignment(location, reference_image, input_image):
     ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
     tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
 
-
-
+    try:
+        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
+        aligned_tgt_bgr = aligned[1]
+    except Exception as e:
+        print(f"Spatial alignment failed: {e}")
+        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None
 
     seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)
 
-    #
+    # Try uploading to DigitalOcean Spaces
     url = "Local storage"
-
-
-
-
-
-        url = "Upload error"
+    try:
+        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
+    except Exception as e:
+        print(f"Upload failed: {e}")
+        url = f"Upload error: {str(e)}"
 
     status = "Spatial alignment: <span style='color:green;font-weight:bold'>Done</span>"
     ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
@@ -259,7 +324,7 @@ def create_interface():
                 url1 = gr.Text(label="Segmentation map URL")
                 status1 = gr.HTML(label="Outlier detection status")
                 res1 = gr.HTML(label="Analysis results")
-            btn1.click(fn=process_coastal_image,inputs=[loc1, inp],outputs=[seg, ovl, res1, status1, url1])
+            btn1.click(fn=process_coastal_image, inputs=[loc1, inp], outputs=[seg, ovl, res1, status1, url1])
 
         with gr.TabItem("Spatially Aligned Segmentation"):
             with gr.Row():
@@ -275,18 +340,33 @@ def create_interface():
             with gr.Row():
                 seg2 = gr.Image(label="Segmentation image", type="numpy", width=disp_w, height=disp_h)
                 ovl2 = gr.Image(label="Overlay image", type="numpy", width=disp_w, height=disp_h)
-
+            url2 = gr.Text(label="Segmentation map URL")
             status2 = gr.HTML(label="Spatial alignment status")
             res2 = gr.HTML(label="Analysis results")
             btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img], outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])
     return demo
 
 if __name__ == "__main__":
+    # Create the required directories
     for path in ["models", "reference_images/MM", "reference_images/SJ"]:
         os.makedirs(path, exist_ok=True)
+
+    # Check that the model files exist
     for p in MODEL_PATHS.values():
         if not os.path.exists(p):
             print(f"Warning: model file {p} does not exist!")
+
+    # Check that the DigitalOcean credentials exist
+    do_creds = [
+        os.environ.get('DO_SPACES_KEY'),
+        os.environ.get('DO_SPACES_SECRET'),
+        os.environ.get('DO_SPACES_REGION'),
+        os.environ.get('DO_SPACES_BUCKET')
+    ]
+    if not all(do_creds):
+        print("Warning: DigitalOcean Spaces credentials are incomplete; uploads may be unavailable")
+
+    # Create and launch the interface
     demo = create_interface()
     # Use an appropriate launch configuration in the HF environment
    if HF_SPACE:
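The new `upload_mask` in app.py is invoked from both processing paths as `upload_mask(Image.fromarray(seg_map), prefix=...)`. A minimal sketch of exercising it directly, not part of the commit, assuming the four `DO_SPACES_*` environment variables are set and that app.py is importable as a module named `app`:

import numpy as np
from PIL import Image
from app import upload_mask  # assumption: app.py can be imported without launching the UI

# Stand-in segmentation map; a real call passes the model's seg_map array.
seg_map = np.zeros((64, 64, 3), dtype=np.uint8)
url = upload_mask(Image.fromarray(seg_map), prefix="MM")
print(url)  # public Spaces URL on success, or an error-message string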
models/put model
ADDED
File without changes
pipeline/HSV.py
ADDED
@@ -0,0 +1,16 @@
+import cv2
+import numpy as np
+
+def preprocess_images(images, V_FIXED=200):
+    fixed_images = []
+    for image in images:
+        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+        hsv_fixed = hsv.copy()
+        hsv_fixed[:, :, 2] = (hsv[:, :, 2] / hsv[:, :, 2].max()) * V_FIXED
+        hsv_fixed[:, :, 1] = hsv_fixed[:, :, 1] * (hsv_fixed[:, :, 2] / hsv[:, :, 2].max())
+        hsv_fixed[:, :, 1] = np.clip(hsv_fixed[:, :, 1], 0, 255)
+
+        fixed_image = cv2.cvtColor(hsv_fixed, cv2.COLOR_HSV2BGR)
+        fixed_images.append(fixed_image)
+    return fixed_images
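`preprocess_images` rescales each image's V (brightness) channel so its maximum equals `V_FIXED` and scales saturation proportionally, making shots taken under different lighting comparable. A usage sketch, not part of the commit, with hypothetical file names:

import cv2
from pipeline.HSV import preprocess_images

# Hypothetical inputs; any list of BGR images read with cv2 works.
images = [cv2.imread("day.jpg"), cv2.imread("dusk.jpg")]
normalized = preprocess_images(images, V_FIXED=200)
cv2.imwrite("dusk_normalized.jpg", normalized[1])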
pipeline/ImgOutlier.py
ADDED
@@ -0,0 +1,200 @@
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from torch import nn
+from torchvision import transforms as tr
+from torchvision.models import vit_h_14
+import cv2
+
+class CosineSimilarity:
+    def __init__(self, vector='feature', threshold=0.8, mean_vec=[], device=None):
+        """
+        Initialize the CosineSimilarity class.
+
+        Args:
+            vector (str): Type of vector to use ('feature' or 'image')
+            threshold (float): Threshold for determining outliers
+            mean_vec (numpy vector): Preloaded reference vector for comparison
+            device (str): Device to use for computation (default: 'mps' if available, else 'cuda' if available, else 'cpu')
+        """
+        if device is None:
+            if torch.backends.mps.is_available():
+                self.device = 'mps'
+            elif torch.cuda.is_available():
+                self.device = 'cuda'
+            else:
+                self.device = 'cpu'
+        else:
+            self.device = device
+
+        self.vector = vector
+        self.threshold = threshold
+        self.model_instance = None
+        self.mean_vec = mean_vec
+
+    def model(self):
+        """Initialize and return the ViT model."""
+        if self.model_instance is None:
+            wt = torchvision.models.ViT_H_14_Weights.DEFAULT
+            self.model_instance = vit_h_14(weights=wt)
+            self.model_instance.heads = nn.Sequential(*list(self.model_instance.heads.children())[:-1])
+            self.model_instance = self.model_instance.to(self.device)
+        return self.model_instance
+
+    def process_image(self, cv2_img):
+        """
+        Process a cv2 image for the model.
+
+        Args:
+            cv2_img: OpenCV image (BGR format)
+
+        Returns:
+            Processed tensor
+        """
+        # Convert BGR to RGB
+        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
+        # Convert to PIL Image
+        pil_img = Image.fromarray(rgb_img)
+
+        # A set of transformations to prepare the image in tensor format
+        transformations = tr.Compose([
+            tr.ToTensor(),
+            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            tr.Resize((518, 518))
+        ])
+
+        # Prepare the image
+        img_tensor = transformations(pil_img).float()
+
+        if self.vector == 'image':
+            img_tensor = img_tensor.flatten()
+
+        img_tensor = img_tensor.unsqueeze_(0)
+
+        if self.vector == 'feature':
+            img_tensor = img_tensor.to(self.device)
+
+        return img_tensor
+
+    def get_embeddings(self, ref_images, test_images):
+        """
+        Get embeddings for reference and test images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            Reference embedding, list of test embeddings
+        """
+        model = self.model()
+
+        # Process test images
+        emb_test = []
+        for img in test_images:
+            processed_img = self.process_image(img)
+            if self.vector == 'feature':
+                emb = model(processed_img).detach().cpu()
+                emb_test.append(emb)
+            else:  # 'image'
+                emb_test.append(processed_img)
+
+        # This checks if a reference vector is loaded; if so, the process of getting
+        # reference embeddings can be skipped for efficiency
+        if len(self.mean_vec) > 0:
+            emb_ref = torch.tensor(self.mean_vec)
+
+        # Process reference images if necessary
+        else:
+            if self.vector == 'feature':
+                # Standard method of getting the reference embedding vector
+                emb_ref_list = []
+                for img in ref_images:
+                    processed_img = self.process_image(img)
+                    emb = model(processed_img).detach().cpu()
+                    emb_ref_list.append(emb)
+
+                # Average the reference embeddings
+                emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0)
+
+            else:  # 'image'
+                emb_ref_list = []
+                for img in ref_images:
+                    processed_img = self.process_image(img)
+                    emb_ref_list.append(processed_img)
+
+                # Average the reference images
+                emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0)
+
+        return emb_ref, emb_test
+
+    def find_outliers(self, ref_images, test_images):
+        """
+        Find outliers in test images compared to reference images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            mask: Boolean array where True indicates an outlier
+            scores: Similarity scores for each test image
+            emb_ref: the reference embedding used
+        """
+        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)
+
+        scores = []
+        mask = []
+
+        for i in range(len(emb_test)):
+            score = torch.nn.functional.cosine_similarity(emb_ref, emb_test[i])
+            score_value = score.item()
+            scores.append(round(score_value, 4))
+            # True if it's an outlier (below threshold)
+            mask.append(score_value <= self.threshold)
+
+        return np.array(mask), scores, emb_ref
+
+    def filter_outliers(self, ref_images, test_images):
+        """
+        Filter out outliers from test images.
+
+        Args:
+            ref_images: List of cv2 reference images
+            test_images: List of cv2 test images
+
+        Returns:
+            filtered_images: List of non-outlier test images
+            outlier_mask: Boolean array where True indicates an outlier
+            scores: Similarity scores for each test image
+        """
+        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)
+
+        # Filter out outliers (keep only non-outliers)
+        filtered_images = [img for i, img in enumerate(test_images) if not outlier_mask[i]]
+
+        return filtered_images, outlier_mask, scores, mean
+
+def detect_outliers(ref_imgs, imgs, mean_vec=[]):
+    """
+    Detects outliers in a set of test images; can use a precomputed reference vector.
+
+    Args:
+        ref_imgs: List of cv2 reference images
+        imgs: List of cv2 test images
+        mean_vec: optional pre-computed reference vector
+
+    Returns:
+        filtered_images: List of non-outlier test images
+        mean: the reference vector used (if a new reference vector should be saved)
+    """
+
+    similarity = CosineSimilarity(vector='feature', threshold=0.8, mean_vec=mean_vec)
+
+    # Get outlier mask, scores, and reference vector
+    outlier_mask, scores, mean_vector = similarity.find_outliers(ref_imgs, imgs)
+
+    # Filter out outliers
+    filtered_images = [img for i, img in enumerate(imgs) if not outlier_mask[i]]
+
+    return filtered_images, mean_vector
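`detect_outliers` embeds images with a ViT-H/14 backbone and flags a test image as an outlier when its cosine similarity to the mean reference embedding falls at or below the 0.8 threshold. Passing a precomputed `mean_vec` skips re-embedding the reference set, which is how app.py uses a saved reference vector. A sketch, not part of the commit, with hypothetical file paths:

import cv2
import numpy as np
from pipeline.ImgOutlier import detect_outliers

ref_imgs = [cv2.imread(p) for p in ["ref1.jpg", "ref2.jpg"]]  # hypothetical reference shots
test_img = cv2.imread("query.jpg")                            # hypothetical test image

# First run: compute the mean reference embedding and persist it.
filtered, mean_vec = detect_outliers(ref_imgs, [test_img])
np.save("reference_vector.npy", np.asarray(mean_vec))

# Later runs: reuse the saved vector so the reference set is not re-embedded.
filtered, _ = detect_outliers(ref_imgs, [test_img], mean_vec=np.load("reference_vector.npy"))
is_outlier = len(filtered) == 0  # the same check app.py performs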
pipeline/normalization.py
ADDED
@@ -0,0 +1,77 @@
+from transformers import AutoImageProcessor, AutoModel
+import torch
+import numpy as np
+import cv2
+
+
+def align_images(images, segs):
+    """
+    Align images using SuperGlue for feature matching.
+
+    Args:
+        images: List of input images
+        segs: List of segmentation images
+
+    Returns:
+        Tuple of (aligned images, aligned segmentation images)
+    """
+    if not images or len(images) < 2:
+        return images, segs
+
+    reference = images[0]
+    reference_seg = segs[0]
+    aligned_images = [reference]
+    aligned_images_seg = [reference_seg]
+
+    # Load SuperGlue model and processor
+    processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+    model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
+
+    for i in range(1, len(images)):
+        current = images[i]
+        current_seg = segs[i]
+
+        # Process image pair
+        image_pair = [reference, current]
+        inputs = processor(image_pair, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # Get matches
+        image_sizes = [[(img.shape[0], img.shape[1]) for img in image_pair]]
+        matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+
+        # Extract matching keypoints
+        match_data = matches[0]
+        keypoints0 = match_data["keypoints0"].numpy()
+        keypoints1 = match_data["keypoints1"].numpy()
+
+        # Filter matches by confidence
+        valid_matches = match_data["matching_scores"] > 0.5
+        if sum(valid_matches) < 4:
+            print(f"Not enough confident matches for image {i}, keeping original")
+            aligned_images.append(current)
+            aligned_images_seg.append(current_seg)
+            continue
+
+        # Get matching points
+        src_pts = keypoints1[valid_matches].reshape(-1, 1, 2)
+        dst_pts = keypoints0[valid_matches].reshape(-1, 1, 2)
+
+        # Find homography
+        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
+
+        if H is not None:
+            # Apply homography
+            h, w = reference.shape[:2]
+            aligned = cv2.warpPerspective(current, H, (w, h))
+            aligned_images.append(aligned)
+            aligned_seg = cv2.warpPerspective(current_seg, H, (w, h))
+            aligned_images_seg.append(aligned_seg)
+        else:
+            print(f"Could not find homography for image {i}, keeping original")
+            aligned_images.append(current)
+            aligned_images_seg.append(current_seg)
+
+    return aligned_images, aligned_images_seg
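`align_images` warps every image after the first onto the first using SuperGlue keypoint matches and a RANSAC homography, keeping the original image when confident matches are scarce. app.py calls it with zero masks because no segmentations exist before inference. A sketch, not part of the commit, with hypothetical file names:

import cv2
import numpy as np
from pipeline.normalization import align_images

ref = cv2.imread("reference.jpg")  # hypothetical paths
tgt = cv2.imread("target.jpg")

# Zero masks stand in for segmentation maps, mirroring the call in app.py.
aligned, _ = align_images([ref, tgt], [np.zeros_like(ref), np.zeros_like(tgt)])
aligned_tgt = aligned[1]  # target warped into the reference frame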
reference_images/MM/New Text Document.txt
ADDED
File without changes
reference_images/SJ/New Text Document.txt
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch
+torchvision
+segmentation_models_pytorch
+gradio
+opencv-python
+numpy
+pillow
+transformers
+boto3
uploader/__init__.py
ADDED
File without changes
uploader/do_spaces.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import boto3
+from botocore.client import Config
+import uuid
+from PIL import Image
+import io
+
+def upload_mask(image, prefix="mask"):
+    """
+    Upload a segmentation mask image to DigitalOcean Spaces.
+
+    Args:
+        image: PIL Image object
+        prefix: filename prefix
+
+    Returns:
+        URL of the uploaded file
+    """
+    try:
+        # Get credentials from environment variables
+        do_key = os.environ.get('DO_SPACES_KEY')
+        do_secret = os.environ.get('DO_SPACES_SECRET')
+        do_region = os.environ.get('DO_SPACES_REGION')
+        do_bucket = os.environ.get('DO_SPACES_BUCKET')
+
+        # Verify that the credentials exist
+        if not all([do_key, do_secret, do_region, do_bucket]):
+            raise ValueError("Missing DigitalOcean Spaces credentials")
+
+        # Create the S3 client
+        session = boto3.session.Session()
+        client = session.client('s3',
+                                region_name=do_region,
+                                endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
+                                aws_access_key_id=do_key,
+                                aws_secret_access_key=do_secret)
+
+        # Generate a unique filename
+        filename = f"{prefix}_{uuid.uuid4().hex}.png"
+
+        # Convert the image to a byte stream
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        img_byte_arr.seek(0)
+
+        # Upload to Spaces
+        client.upload_fileobj(
+            img_byte_arr,
+            do_bucket,
+            filename,
+            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
+        )
+
+        # Return the public URL
+        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
+        return url
+
+    except Exception as e:
+        print(f"Upload failed: {str(e)}")
+        return f"Upload error: {str(e)}"
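uploader/do_spaces.py duplicates the `upload_mask` that this commit inlines into app.py, with one small difference: missing credentials raise a ValueError here (immediately converted to an error string by the function's own except handler), while the app.py copy returns a message string directly; app.py never imports this module. A standalone usage sketch, not part of the commit, with hypothetical values:

import os
from PIL import Image
from uploader.do_spaces import upload_mask

# Hypothetical credentials; in a Space these would come from repository secrets.
os.environ['DO_SPACES_KEY'] = '...'
os.environ['DO_SPACES_SECRET'] = '...'
os.environ['DO_SPACES_REGION'] = 'nyc3'
os.environ['DO_SPACES_BUCKET'] = 'my-bucket'

url = upload_mask(Image.open("mask.png"), prefix="test")  # hypothetical mask file
print(url)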