# PromptNet/modules/visual_extractor.py
# (uploaded by fenglinliu — commit 6e32a75, "Upload 55 files")
import os
# Route Hugging Face Hub downloads through the hf-mirror.com mirror.
# NOTE(review): must be set BEFORE importing medclip, which pulls in
# huggingface_hub — the hub library reads HF_ENDPOINT at import time.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class VisualExtractor(nn.Module):
    """Extract MedCLIP image embeddings and fuse them with a prompt-attention summary.

    For each image, the MedCLIP (ViT) image embedding attends over a bank of
    pre-computed prompt embeddings loaded from ``prompt/prompt.pth``; the
    attention-weighted prompt sum is concatenated to the image embedding.
    """

    def __init__(self, args):
        """Load the pretrained MedCLIP ViT model, its processor, and the prompt bank.

        Args:
            args: experiment configuration (kept for interface compatibility;
                not read by this module).
        """
        super(VisualExtractor, self).__init__()
        self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.model.from_pretrained()  # loads pretrained weights in place
        self.model.cuda()
        self.processor = MedCLIPProcessor()
        # Prompt bank; assumed shape (num_prompts, 1, embed_dim) based on the
        # bmm/broadcast usage in forward() — TODO confirm against prompt.pth.
        # (no_grad is not strictly needed around torch.load; kept for parity
        # with the original code.)
        with torch.no_grad():
            self.prompt = torch.load('prompt/prompt.pth')

    def forward(self, images):
        """Compute fused (prompt-attended + raw) visual features per image.

        Args:
            images: iterable of images accepted by ``MedCLIPProcessor``
                (e.g. PIL images), one per batch element.

        Returns:
            patch_feats: tensor of the fused feature repeated over 49
                pseudo-patch positions — presumably (batch, 49, 2*D).
            avg_feats: tensor of the fused feature per image — presumably
                (batch, 2*D). TODO confirm D against MedCLIP's img_embeds.
        """
        device = next(self.model.parameters()).device
        embeds = []
        for image in images:
            inputs = self.processor(
                text="lungs", images=image, return_tensors="pt", padding=True)
            # The model was moved to CUDA in __init__ but the processor returns
            # CPU tensors — move inputs to the model's device (no-op when
            # already co-located).
            inputs = {k: (v.to(device) if torch.is_tensor(v) else v)
                      for k, v in inputs.items()}
            outputs = self.model(**inputs)
            embeds.append(outputs['img_embeds'])          # assumed (1, D)
        batch_feats = torch.stack(embeds, dim=0)          # (B, 1, D)

        attended = []
        for i in range(batch_feats.shape[0]):
            b = batch_feats[i].unsqueeze(1)               # (1, 1, D)
            # One copy of the image embedding per prompt, as a column vector.
            b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1)  # (P, D, 1)
            c_t = torch.bmm(self.prompt, b)               # (P, 1, 1) similarity scores
            c_t = c_t.float()
            # Explicit dim=0 (softmax over the P prompts). This matches the
            # legacy implicit-dim rule the original F.softmax(c_t) resolved to
            # for 3-D input, and silences the deprecation warning.
            alpha = F.softmax(c_t, dim=0)
            weighted = alpha * self.prompt                # (P, 1, D)
            attended.append(weighted.sum(dim=0))          # (1, D)
        featsem = torch.stack(attended, dim=0)            # (B, 1, D)

        # Concatenate prompt summary with the raw image embedding.
        feats = torch.cat((featsem, batch_feats), dim=2)  # (B, 1, 2D)
        patch_feats = feats.repeat(1, 49, 1)              # (B, 49, 2D)
        avg_feats = feats.squeeze(1)                      # (B, 2D)
        return patch_feats, avg_feats