import os

# Route Hugging Face downloads through a mirror endpoint; set before importing
# medclip so any checkpoint downloads it triggers use the mirror.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from medclip import MedCLIPModel, MedCLIPProcessor, MedCLIPVisionModelViT

import torch
import torch.nn as nn
import torch.nn.functional as F

class VisualExtractor(nn.Module):
    """Extracts image features with a pretrained MedCLIP ViT encoder and fuses
    them with a bank of precomputed prompt embeddings via soft attention."""

    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.model.from_pretrained()
        self.model.cuda()
        self.processor = MedCLIPProcessor()
        # Precomputed prompt embeddings, expected shape (num_prompts, 1, embed_dim).
        # Moved to the GPU so they live on the same device as the image features.
        self.prompt = torch.load('prompt/prompt.pth').cuda()


    def forward(self, images):
        # Encode each image with MedCLIP; the fixed text "lungs" is supplied
        # only because the joint model expects both modalities as input.
        feat_list = []
        for image in images:
            inputs = self.processor(text="lungs", images=image,
                                    return_tensors="pt", padding=True)
            # Move the processor outputs to the GPU, where the model was
            # placed in __init__, to avoid a device mismatch.
            inputs = {k: v.cuda() for k, v in inputs.items()}
            outputs = self.model(**inputs)
            feats = outputs['img_embeds']            # (1, embed_dim)
            feat_list.append(feats)
        batch_feats = torch.stack(feat_list, dim=0)  # (batch, 1, embed_dim)

        # Prompt attention: score every prompt embedding against each image
        # feature, then take the softmax-weighted sum of the prompts.
        attended = []
        for i in range(batch_feats.shape[0]):
            b = batch_feats[i].unsqueeze(1)                              # (1, 1, embed_dim)
            b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1)   # (num_prompts, embed_dim, 1)
            c_t = torch.bmm(self.prompt, b).float()                      # (num_prompts, 1, 1)
            alpha = F.softmax(c_t, dim=0)            # normalize over prompts
            weighted = alpha * self.prompt           # broadcast to (num_prompts, 1, embed_dim)
            attended.append(weighted.sum(dim=0))     # (1, embed_dim)
        featsem = torch.stack(attended, dim=0)       # (batch, 1, embed_dim)

        # Concatenate the attended prompt features with the raw image features
        # along the channel dimension: (batch, 1, 2 * embed_dim).
        feats = torch.cat((featsem, batch_feats), dim=2)

        # Tile the fused vector 49 times, presumably to mimic the 7x7 patch
        # grid a downstream decoder expects; the pooled feature is the fused
        # vector itself.
        patch_feats = feats.repeat(1, 49, 1)  # (batch, 49, 2 * embed_dim)
        avg_feats = feats.squeeze(1)          # (batch, 2 * embed_dim)

        return patch_feats, avg_feats
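
# --- Usage sketch (illustrative smoke test, not part of the module) ---
# Assumptions: a CUDA device is available, the MedCLIP checkpoint can be
# downloaded, and 'prompt/prompt.pth' exists. The dummy images and `args=None`
# below are placeholders, not values from the original project.
if __name__ == '__main__':
    import numpy as np
    from PIL import Image

    # Two blank 224x224 grayscale images stand in for chest X-rays
    # (assumption: the MedCLIP processor accepts any PIL image and resizes it).
    dummy_images = [Image.fromarray(np.zeros((224, 224), dtype=np.uint8))
                    for _ in range(2)]

    extractor = VisualExtractor(args=None)  # `args` is accepted but unused here
    with torch.no_grad():
        patch_feats, avg_feats = extractor(dummy_images)
    print(patch_feats.shape)  # expected: (2, 49, 2 * embed_dim)
    print(avg_feats.shape)    # expected: (2, 2 * embed_dim)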