Gent (PG/R - Comp Sci & Elec Eng) committed on
Commit
460258c
·
1 Parent(s): a52e395

Add application file

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import vision_transformer as models
4
+ import cv2
5
+ from torch import nn
6
+ from utils import load_pretrained_weights
7
+
8
+
9
class PatchEmbedding:
    """
    Wrap a pretrained ViT backbone and expose its CLS token and patch tokens.

    Args:
        pretrained_weights (str): Path to the pretrained weight file.
        arch (str, optional): Name of the architecture constructor looked up
            in the ``vision_transformer`` module. Defaults to "vit_base".
        patch_size (int, optional): Size of the patches extracted from the
            image. Defaults to 16.

    Attributes:
        model: The image embedding model (eval mode, gradients disabled).
        embed_dim (int): Embedding dimension reported by the model.
        tfms: Preprocessing pipeline: ``ToTensor`` followed by ImageNet
            mean/std normalization.

    Methods:
        load_pretrained_weights(pretrained_weights): Load pretrained weights
            into the model.
        get_representation(image): Return the normalized CLS token and patch
            tokens of a single image.
    """

    # Per-channel statistics used to normalize the input image.
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        # num_classes=0 is forwarded to the architecture constructor
        # (presumably it drops the classification head -- TODO confirm).
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)

        # Imported locally, as in the original, so the module-level import
        # block of the file does not need to change.
        from torchvision import transforms

        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                self.IMAGENET_DEFAULT_MEAN, self.IMAGENET_DEFAULT_STD
            ),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load the weights stored at ``pretrained_weights`` into the model."""
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """
        Compute the L2-normalized tokens of a single input image.

        Args:
            image: Input image in any form accepted by ``self.tfms``
                (presumably an H x W x C uint8 array as delivered by
                Gradio -- TODO confirm).

        Returns:
            tuple: ``(cls_token, patch_tokens)`` as NumPy arrays, where
            ``cls_token`` is the first row of the model output for this image
            and ``patch_tokens`` are all remaining rows. Assuming the model
            returns a (1, N, C) tensor, the shapes are (C,) and (N - 1, C).
        """
        batch = self.tfms(image)[None, :]  # add the batch dimension
        # Inference only: never build an autograd graph, whatever the model
        # or its output tensors are configured to do.
        with torch.no_grad():
            tokens = self.model.forward_features(batch)[0]  # N, C
            tokens = nn.functional.normalize(tokens, dim=-1, p=2)
        # detach()/cpu() make the NumPy conversion safe even if the output
        # tracks gradients or lives on another device.
        tokens = tokens.detach().cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Run the backbone on an already preprocessed batch ``x``."""
        return self.model.forward_features(x)
61
+
62
# Shape passed as ``shape=`` to every Gradio image input below.
default_shape = 224, 224

# Single shared embedding model, built once at import time from the
# bundled checkpoint.
embedding = PatchEmbedding('weights/mmc.pth')
64
+
65
+
66
def classify(query_image, support_image):
    """
    Score how similar two images are using their CLS tokens.

    Args:
        query_image: Query image, passed to ``embedding.get_representation``.
        support_image: Support image, passed to
            ``embedding.get_representation``.

    Returns:
        str: Sum of the element-wise product of the two CLS tokens, scaled
        by 100 and formatted with two decimals.
    """
    query_cls, _ = embedding.get_representation(query_image)
    support_cls, _ = embedding.get_representation(support_image)

    score = (query_cls * support_cls).sum() * 100
    return "{:.2f}".format(score)
73
+
74
def segment(threshold, input):
    """
    Keep the image regions whose patch tokens resemble the sketched area.

    The sketch mask selects a set of patches; their mean token acts as a
    prototype, and every patch whose similarity to it exceeds ``threshold``
    is kept while the rest of the image is blacked out.

    Args:
        threshold (float): Similarity cut-off applied to each patch.
        input (dict): Mapping with keys ``'image'`` and ``'mask'`` (the name
            shadows the builtin but is kept for interface compatibility).
            Assumes both are 224 x 224 arrays and that the model yields a
            14 x 14 patch grid -- TODO confirm against the Gradio component.

    Returns:
        ndarray: ``uint8`` image of the same shape as ``input['image']``.
        It is entirely black when the sketch selects no patch.
    """
    grid_size = (14, 14)      # patch grid assumed for a 224 px input, patch 16
    output_size = (224, 224)  # size the patch-level mask is scaled back to

    image = input['image']
    sketch = input['mask']

    # Down-sample the first channel of the sketch to one flag per patch.
    select = (cv2.resize(sketch[:, :, 0], grid_size) > 0).flatten()
    if not select.any():
        # Nothing was sketched: averaging an empty selection would only give
        # a NaN prototype (plus a RuntimeWarning) and thus an all-black
        # result, so return that result directly.
        return (image * 0).astype('uint8')

    patch_tokens = embedding.get_representation(image)[1]
    q_pat = patch_tokens[select].mean(0)  # C
    sim = patch_tokens @ q_pat[:, None]  # N, 1

    keep = (sim.reshape(grid_size) > threshold).astype('float')
    keep = cv2.resize(keep, output_size)
    ans = image * keep[:, :, None]
    return ans.astype('uint8')
88
+
89
# --- Classification tab: two images in, similarity score out ---------------
query_input = gr.inputs.Image(label="Query Image", shape=default_shape)
support_input = gr.inputs.Image(label="Support Image", shape=default_shape)
prediction_output = gr.outputs.Textbox(label="Prediction")

classification_tab = gr.Interface(
    fn=classify,
    inputs=[query_input, support_input],
    outputs=prediction_output,
    title="Classification",
)

# --- Segmentation tab: threshold + sketched image in, masked image out -----
threshold_input = gr.inputs.Slider(
    0, 1, step=0.01, label="Threshold", default=0.8
)
sketch_input = gr.inputs.Image(
    label="Input Image", tool="sketch", shape=default_shape
)
segmentation_output = gr.outputs.Image('numpy', label='Segmentation')

segmentation_tab = gr.Interface(
    fn=segment,
    inputs=[threshold_input, sketch_input],
    outputs=segmentation_output,
    title="Segmentation",
)

# Combine both demos into one tabbed app and start serving it.
tabs = [classification_tab, segmentation_tab]
tab_names = ["Classification", "Segmentation"]
interface = gr.TabbedInterface(tabs, tab_names)

interface.launch()