File size: 4,352 Bytes
460258c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f6b3f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460258c
0f6b3f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
import vision_transformer as models
import cv2
from torch import nn
from utils import load_pretrained_weights


class PatchEmbedding:
    """Frozen vision-transformer backbone that turns an image into tokens.

    The backbone is built from the ``vision_transformer`` module (imported as
    ``models``), switched to eval mode with gradients disabled, and then
    loaded with pretrained weights.

    Args:
        pretrained_weights (str): Path to the pretrained weight file.
        arch (str, optional): Name of the constructor looked up in
            ``vision_transformer``. Defaults to ``'vit_base'``.
        patch_size (int, optional): Patch size forwarded to the constructor.
            Defaults to 16.

    Attributes:
        model: The backbone network.
        embed_dim (int): Embedding dimension, copied from ``model.embed_dim``.
        tfms: Preprocessing pipeline (``ToTensor`` followed by ImageNet
            mean/std normalization).

    Methods:
        load_pretrained_weights(pretrained_weights): Load weights into the model.
        get_representation(image): Return the L2-normalized class token and
            patch tokens of one image as NumPy arrays.
    """

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        # Inference only: freeze the parameters and disable train-time layers.
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)

        # Imported lazily so torchvision is only required once an instance
        # is actually constructed.
        from torchvision import transforms
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load the weights stored at ``pretrained_weights`` into ``self.model``.

        Delegates to the module-level ``load_pretrained_weights`` helper
        imported from ``utils``.
        """
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """Compute the L2-normalized tokens of a single image.

        Args:
            image: A single image in whatever form ``self.tfms`` accepts
                (not a file path).

        Returns:
            tuple: ``(cls_token, patch_tokens)`` as NumPy arrays, where
            ``cls_token`` is the first token (shape ``(C,)``) and
            ``patch_tokens`` holds the remaining tokens (shape ``(N - 1, C)``).
            Assumes ``forward_features`` returns ``(batch, N, C)`` with the
            class token first, as the indexing below implies.
        """
        batch = self.tfms(image)[None, :]  # add the batch dimension
        # no_grad avoids building an autograd graph during inference.
        with torch.no_grad():
            tokens = self.model.forward_features(batch)[0]  # N, C
            tokens = nn.functional.normalize(tokens, dim=-1, p=2)
        # detach()/cpu() make the NumPy conversion safe even if the backbone
        # returned a graph-attached or non-CPU tensor.
        tokens = tokens.detach().cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Run the backbone's ``forward_features`` on an already-preprocessed batch."""
        return self.model.forward_features(x)
    
# Shape handed to every gr.inputs.Image component defined further down.
default_shape = (224, 224)
# Single shared backbone, created once at import time from weights/mmc.pth.
embedding = PatchEmbedding("weights/mmc.pth")


def classify(query_image, support_image):
    """Score how similar two images are from their class tokens.

    Args:
        query_image: Image passed to ``embedding.get_representation``.
        support_image: Image passed to ``embedding.get_representation``.

    Returns:
        str: The dot product of the two class tokens multiplied by 100 and
        formatted with two decimals. The tokens are L2-normalized by
        ``get_representation``, so this is a scaled cosine similarity.
    """
    query_cls, _ = embedding.get_representation(query_image)
    support_cls, _ = embedding.get_representation(support_image)

    similarity = (query_cls * support_cls).sum() * 100
    return f"{similarity:.2f}"

def segment(threshold, input):
    """Keep the image regions whose patch tokens resemble the sketched area.

    The sketched mask is reduced to the patch grid, the tokens of the selected
    patches are averaged into one prototype, and every patch whose similarity
    to that prototype exceeds ``threshold`` is kept.

    Args:
        threshold (float): Similarity cut-off; patches scoring above it are kept.
        input (dict): Mapping with an ``'image'`` array and a ``'mask'`` array,
            as delivered by the sketch-enabled image component. The name
            shadows the builtin but is kept for interface compatibility.

    Returns:
        ndarray: ``uint8`` image with the non-matching regions blacked out.
        An all-black image is returned when nothing was sketched.
    """
    # Patch grid side and output size; these match a 224x224 input with the
    # default patch_size of 16 (224 / 16 = 14).
    grid = 14
    out_size = (224, 224)

    image = input['image']
    sketch = input['mask']

    patch_tokens = embedding.get_representation(image)[1]
    selected = (cv2.resize(sketch[:, :, 0], (grid, grid)) > 0).flatten()
    if not selected.any():
        # With no patch selected the mean below would run over an empty slice
        # (NaN + RuntimeWarning). Return the all-black result explicitly
        # instead of relying on NaN comparisons.
        return (image * 0).astype('uint8')

    prototype = patch_tokens[selected].mean(0)  # C
    sim = patch_tokens @ prototype[:, None]  # N, 1

    keep = (sim.reshape(grid, grid) > threshold).astype('float')
    keep = cv2.resize(keep, out_size)
    masked = image * keep[:, :, None]
    return masked.astype('uint8')

# "Classification" tab: takes a query image and a support image and shows the
# score string returned by classify().
classification_tab = gr.Interface(
    title="Classification",
    fn=classify,
    inputs=[
        gr.inputs.Image(shape=default_shape, label="Query Image"),
        gr.inputs.Image(shape=default_shape, label="Support Image"),
    ],
    outputs=gr.outputs.Textbox(label="Prediction"),
)

# "Segmentation" tab: a threshold slider plus a sketchable image; the output
# is the masked image returned by segment().
segmentation_tab = gr.Interface(
    title="Segmentation",
    fn=segment,
    inputs=[
        gr.inputs.Slider(0, 1, step=0.01, default=0.8, label="Threshold"),
        gr.inputs.Image(shape=default_shape, tool="sketch", label="Input Image"),
    ],
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
)

# BibTeX entry rendered verbatim at the top of the page.
_CITATION = """
@misc{wu2023masked,
      title={Masked Momentum Contrastive Learning for Zero-shot Semantic Understanding}, 
      author={Jiantao Wu and Shentong Mo and Muhammad Awais and Sara Atito and Zhenhua Feng and Josef Kittler},
      year={2023},
      eprint={2308.11448},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}"""

# Page layout: the citation followed by the two demo tabs.
with gr.Blocks() as app:
    gr.Markdown(_CITATION)
    interface = gr.TabbedInterface(
        [classification_tab, segmentation_tab],
        ["Classification", "Segmentation"],
    )

# Start the Gradio server as soon as the module is executed.
app.launch()