import math

import torch
import torch.nn as nn

from modules.decoder import DeCap
from medclip import MedCLIPModel, MedCLIPVisionModelViT


class MedCapModel(nn.Module):
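    """Couples a pretrained MedCLIP encoder with a DeCap decoder for chest X-ray report
    generation: training conditions the decoder on report text embeddings, while sampling
    conditions it on (optionally prompt-projected) image embeddings."""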
    def __init__(self, args, tokenizer):
        super(MedCapModel, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.tokenizer = tokenizer
        self.model = DeCap(args, tokenizer)

        self.align_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.align_model.from_pretrained()
        self.prompt = torch.load(args.prompt)
        # Projects the concatenated text+image feature back to a single embedding for the
        # F_version == 'v1' fine-tuning path; assumes a 512-dim MedCLIP embedding space.
        self.fc_reduce_dim = nn.Linear(512 * 2, 512)
        if args.dataset == 'iu_xray':
            self.forward = self.forward_iu_xray
        else:
            self.forward = self.forward_mimic_cxr

    def noise_injection(self, x, variance=0.001, modality_offset=None, dont_norm=False):
        """Add Gaussian noise to the text embedding during training so the decoder stays
        robust to the text-image modality gap at inference time."""
        if variance == 0.0:
            return x
        std = math.sqrt(variance)
        if not dont_norm:
            x = torch.nn.functional.normalize(x, dim=1)
        # todo: by some conventions, multivariate noise should be divided by sqrt of dim
        x = x + torch.randn(x.shape, device=x.device) * std
        if modality_offset is not None:
            x = x + modality_offset
        return torch.nn.functional.normalize(x, dim=1)

    def align_encode_images_iu_xray(self, images):
        # Split the images
        image1, image2 = images.unbind(dim=1)
        # Encode each image
        feature1 = self.align_model.encode_image(image1)
        feature2 = self.align_model.encode_image(image2)
        if self.args.prompt_load == 'yes':
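            # Project each view's image feature onto the loaded prompt bank: the softmax over
            # scaled similarities gives a convex combination of prompt embeddings, keeping the
            # decoder input close to the text-embedding space it was trained on.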
            sim_1 = feature1 @ self.prompt.T.float()
            sim_1 = (sim_1 * 100).softmax(dim=-1)
            prefix_embedding_1 = sim_1 @ self.prompt.float()
            prefix_embedding_1 /= prefix_embedding_1.norm(dim=-1, keepdim=True)

            sim_2 = feature2 @ self.prompt.T.float()
            sim_2 = (sim_2 * 100).softmax(dim=-1)
            prefix_embedding_2 = sim_2 @ self.prompt.float()
            prefix_embedding_2 /= prefix_embedding_2.norm(dim=-1, keepdim=True)
            averaged_prompt_features = torch.mean(torch.stack([prefix_embedding_1, prefix_embedding_2]), dim=0)
            return averaged_prompt_features
        else:
            # Average the two view features
            averaged_features = torch.mean(torch.stack([feature1, feature2]), dim=0)
            return averaged_features

    def align_encode_images_mimic_cxr(self, images):
        feature = self.align_model.encode_image(images)
        if self.args.prompt_load == 'yes':
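            # Same prompt-bank projection as the IU X-Ray branch, applied to the single view.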
            sim = feature @ self.prompt.T.float()
            sim = (sim * 100).softmax(dim=-1)
            prefix_embedding = sim @ self.prompt.float()
            prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True)
            return prefix_embedding
        else:
            return feature

    def forward_iu_xray(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
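        """Two-view IU X-Ray forward pass: 'train' conditions the decoder on the (optionally
        noise-injected) report text embedding, 'sample' on the encoded image views."""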
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()

        align_image_feature = None
        if self.args.train_mode == 'fine-tuning':
            align_image_feature = self.align_encode_images_iu_xray(images)
        if mode == 'train':
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)

            if self.args.train_mode == 'fine-tuning':
                if self.args.F_version == 'v1':
                    combined_feature = torch.cat([align_text_feature, align_image_feature], dim=-1)
                    align_text_feature = self.fc_reduce_dim(combined_feature)
                elif self.args.F_version == 'v2':
                    align_text_feature = align_image_feature

            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            logits = outputs.logits
            logits = logits[:, :-1]
            return logits
        elif mode == 'sample':
            align_image_feature = self.align_encode_images_iu_xray(images)
            outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
            return outputs
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def forward_mimic_cxr(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
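        """Single-image MIMIC-CXR counterpart of forward_iu_xray."""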
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()
        if mode == 'train':
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)
            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            logits = outputs.logits
            logits = logits[:, :-1]
            return logits
        elif mode == 'sample':
            align_image_feature = self.align_encode_images_mimic_cxr(images)
            outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
            return outputs
        else:
            raise ValueError(f"Unknown mode: {mode}")
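

# Minimal usage sketch (hypothetical variable names; assumes an argparse-style `args`
# providing the fields referenced above and a tokenizer compatible with DeCap):
#
#   model = MedCapModel(args, tokenizer)
#   model = model.to(model.device)
#   logits = model(reports_ids, align_ids, align_masks, images, mode='train')
#   reports = model(reports_ids, align_ids, align_masks, images, mode='sample')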