Spaces:
Sleeping
Sleeping
Gent (PG/R - Comp Sci & Elec Eng)
committed on
Commit
·
460258c
1
Parent(s):
a52e395
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import vision_transformer as models
|
4 |
+
import cv2
|
5 |
+
from torch import nn
|
6 |
+
from utils import load_pretrained_weights
|
7 |
+
|
8 |
+
|
9 |
+
class PatchEmbedding:
    """
    Wrap a pretrained ViT backbone and expose its class and patch tokens.

    Args:
        pretrained_weights (str): Path to the pretrained weights file.
        arch (str, optional): Name of the architecture constructor looked up
            in the ``vision_transformer`` module. Defaults to "vit_base".
        patch_size (int, optional): Size of the patches extracted from the
            image. Defaults to 16.
    Attributes:
        model: The image embedding model (eval mode, gradients disabled).
        embed_dim (int): Dimension of the image embedding.
        tfms: Preprocessing pipeline (``ToTensor`` followed by normalization
            with the ImageNet mean and standard deviation).
    Methods:
        load_pretrained_weights(pretrained_weights): Load pretrained weights
            into the model.
        get_representation(image): Return the L2-normalized class token and
            patch tokens of an input image.
    """

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        # Inference only: freeze the parameters and switch to eval mode.
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)

        # Imported here because the file's top-level imports do not include
        # torchvision.
        from torchvision import transforms
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load the weights stored at ``pretrained_weights`` into the model."""
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """
        Compute the L2-normalized class token and patch tokens of one image.

        Args:
            image: A single image accepted by ``self.tfms`` (presumably an
                H x W x C array as delivered by Gradio; confirm against
                the caller).
        Returns:
            tuple: ``(cls_token, patch_tokens)`` as NumPy arrays. With N
            tokens of dimension C produced for the image, ``cls_token`` has
            shape (C,) and ``patch_tokens`` has shape (N - 1, C).
        """
        img = self.tfms(image)
        x = img[None, :]  # add the batch dimension
        # No autograd graph is needed for inference.
        with torch.no_grad():
            # Assumes forward_features returns (batch, N, C) with the class
            # token first - TODO confirm against vision_transformer.
            tokens = self.model.forward_features(x)[0]  # N, C
            tokens = nn.functional.normalize(tokens, dim=-1, p=2)
        # detach()/cpu() keep the NumPy conversion valid even if the output
        # is attached to autograd or lives on another device.
        tokens = tokens.detach().cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Return the model's raw ``forward_features`` output for batch ``x``."""
        return self.model.forward_features(x)
|
61 |
+
|
62 |
+
# Size that every uploaded image is resized to by the Gradio inputs below.
default_shape = (224, 224)

# Shared feature extractor, built once at import time from the bundled weights.
embedding = PatchEmbedding('weights/mmc.pth')
|
64 |
+
|
65 |
+
|
66 |
+
def classify(query_image, support_image):
    """
    Score how similar two images are using their class tokens.

    Args:
        query_image: Image passed to ``embedding.get_representation``.
        support_image: Image passed to ``embedding.get_representation``.
    Returns:
        str: Dot product of the two class tokens multiplied by 100,
        formatted with two decimals.
    """
    query_cls, _ = embedding.get_representation(query_image)
    support_cls, _ = embedding.get_representation(support_image)

    score = (query_cls * support_cls).sum() * 100
    return f"{score:.2f}"
|
73 |
+
|
74 |
+
def segment(threshold, input):
    """
    Keep the image regions whose patch tokens resemble the sketched area.

    The patch tokens under the user's sketch are averaged into a prototype;
    every patch whose dot product with that prototype exceeds ``threshold``
    is kept, and the rest of the image is blacked out.

    Args:
        threshold (float): Minimum similarity for a patch to be kept.
        input (dict): Mapping with an ``'image'`` entry (the picture) and a
            ``'mask'`` entry (the sketch; its first channel is read and any
            value above zero counts as selected).
    Returns:
        ndarray: ``uint8`` array shaped like the input image with the
        unselected regions set to zero.
    """
    image = input['image']
    sketch = input['mask']

    patch_tokens = embedding.get_representation(image)[1]  # N, C
    # Derive the patch grid from the token count instead of hard-coding it;
    # assumes a square grid (14 x 14 for a 224 px image with 16 px patches).
    side = int(round(len(patch_tokens) ** 0.5))
    selected = (cv2.resize(sketch[:, :, 0], (side, side)) > 0).flatten()
    if not selected.any():
        # Nothing was sketched: averaging zero tokens would give a NaN
        # prototype and an all-black result, so return that result directly.
        return (image * 0).astype('uint8')

    prototype = patch_tokens[selected].mean(0)  # C
    sim = patch_tokens @ prototype[:, None]  # N, 1

    mask = (sim.reshape(side, side) > threshold).astype('float')
    # cv2.resize takes (width, height); match the image being masked.
    height, width = image.shape[:2]
    mask = cv2.resize(mask, (width, height))
    ans = image * mask[:, :, None]
    return ans.astype('uint8')
|
88 |
+
|
89 |
+
# Tab 1: compare a query image against a support image.
classification_tab = gr.Interface(
    fn=classify,
    inputs=[
        gr.inputs.Image(label="Query Image", shape=default_shape),
        gr.inputs.Image(label="Support Image", shape=default_shape),
    ],
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification",
)
|
99 |
+
|
100 |
+
# Tab 2: sketch on an image and keep the regions similar to the sketch.
segmentation_tab = gr.Interface(
    fn=segment,
    inputs=[
        gr.inputs.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            label="Threshold",
            default=0.8,
        ),
        gr.inputs.Image(
            label="Input Image",
            tool="sketch",
            shape=default_shape,
        ),
    ],
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation",
)
|
109 |
+
|
110 |
+
# Combine both demos into a single tabbed app and start serving it.
interface = gr.TabbedInterface(
    [classification_tab, segmentation_tab],
    ["Classification", "Segmentation"],
)

interface.launch()
|