import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # route Hugging Face downloads through the hf-mirror proxy

from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
import torch
import torch.nn as nn
import torch.nn.functional as F

class VisualExtractor(nn.Module):
    """Encodes images with MedCLIP and fuses them with a precomputed prompt bank."""

    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        # Pretrained MedCLIP with a ViT vision backbone; requires a CUDA device.
        self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.model.from_pretrained()
        self.model.cuda()
        self.processor = MedCLIPProcessor()
        # Precomputed prompt bank; the shapes in forward() imply (K, 1, 512).
        self.prompt = torch.load('prompt/prompt.pth')
    def forward(self, images):
        # Encode each image with MedCLIP. The processor requires some text, so a
        # fixed prompt ("lungs") is passed, but only the image embeddings are kept.
        img_feats = []
        for image in images:
            inputs = self.processor(text="lungs", images=image,
                                    return_tensors="pt", padding=True)
            outputs = self.model(**inputs)
            img_feats.append(outputs['img_embeds'])                     # (1, 512)
        batch_feats = torch.stack(img_feats, dim=0)                     # (B, 1, 512)

        # For each image, attend over the K prompt vectors and pool them
        # into a single semantic feature.
        pooled = []
        for i in range(batch_feats.shape[0]):
            b = batch_feats[i].unsqueeze(1)                             # (1, 1, 512)
            b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1)  # (K, 512, 1)
            c_t = torch.bmm(self.prompt, b).float()                     # (K, 1, 1) image-prompt scores
            alpha = F.softmax(c_t, dim=0)                               # attention weights over the K prompts
            weighted = alpha * self.prompt                              # (K, 1, 512)
            pooled.append(weighted.sum(dim=0))                          # (1, 512)
        featsem = torch.stack(pooled, dim=0)                            # (B, 1, 512)

        # Concatenate the semantic and visual features, then tile the fused
        # vector into 49 identical pseudo-patch features (a 7x7 grid).
        feats = torch.cat((featsem, batch_feats), dim=2)                # (B, 1, 1024)
        patch_feats = feats.repeat(1, 49, 1)                            # (B, 49, 1024)
        avg_feats = feats.squeeze(1)                                    # (B, 1024)
        return patch_feats, avg_feats
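
# Minimal usage sketch, not part of the original file: the prompt-bank shape
# (K, 1, 512), the random prompt bank saved below, and the black test images
# are all assumptions inferred from the shapes used in forward(). Requires a
# CUDA device, since MedCLIP moves its inputs to the GPU internally.
if __name__ == '__main__':
    import numpy as np
    from PIL import Image

    # Hypothetical prompt bank with K=16 random prompt vectors, saved where
    # __init__ expects to find it.
    os.makedirs('prompt', exist_ok=True)
    torch.save(torch.randn(16, 1, 512).cuda(), 'prompt/prompt.pth')

    extractor = VisualExtractor(args=None)  # args is accepted but unused
    dummy_images = [Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
                    for _ in range(2)]
    patch_feats, avg_feats = extractor(dummy_images)
    print(patch_feats.shape)  # expected: torch.Size([2, 49, 1024])
    print(avg_feats.shape)    # expected: torch.Size([2, 1024])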