import torch

class HeadVQA(torch.nn.Module):
    """Single VQA head operating on concatenated CLIP image and question features."""

    def __init__(self, train_config):
        super().__init__()

        # Embedding dimensionality of the supported CLIP backbones.
        embedding_size = {'RN50': 1024,
                          'RN101': 512,
                          'RN50x4': 640,
                          'RN50x16': 768,
                          'RN50x64': 1024,
                          'ViT-B/32': 512,
                          'ViT-B/16': 512,
                          'ViT-L/14': 768,
                          'ViT-L/14@336px': 768}

        # Number of distinct auxiliary target classes defined by the mapping.
        n_aux_classes = len(set(train_config.aux_mapping.values()))

        # Shared projection of the concatenated image/question features.
        self.ln1 = torch.nn.LayerNorm(embedding_size[train_config.model] * 2)
        self.dp1 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(embedding_size[train_config.model] * 2, 512)

        # Main VQA classifier.
        self.ln2 = torch.nn.LayerNorm(512)
        self.dp2 = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(512, train_config.n_classes)

        # Auxiliary classifier whose prediction gates the VQA logits.
        self.fc_aux = torch.nn.Linear(512, n_aux_classes)
        self.fc_gate = torch.nn.Linear(n_aux_classes, train_config.n_classes)
        self.act_gate = torch.nn.Sigmoid()

    def forward(self, img_features, question_features):
        # Fuse image and question features by concatenation.
        xc = torch.cat((img_features, question_features), dim=-1)

        x = self.ln1(xc)
        x = self.dp1(x)
        x = self.fc1(x)

        # Auxiliary prediction, returned alongside the main output.
        aux = self.fc_aux(x)

        # Sigmoid gate derived from the auxiliary prediction.
        gate = self.fc_gate(aux)
        gate = self.act_gate(gate)

        # Main VQA logits, modulated element-wise by the gate.
        x = self.ln2(x)
        x = self.dp2(x)
        vqa = self.fc2(x)

        output = vqa * gate

        return output, aux
    
    
class NetVQA(torch.nn.Module):
    """Ensemble of HeadVQA heads (one per fold) whose probabilities are averaged."""

    def __init__(self, train_config):
        super().__init__()

        self.heads = torch.nn.ModuleList()

        # folds may be a list of fold ids or simply the number of folds.
        if isinstance(train_config.folds, list):
            self.num_heads = len(train_config.folds)
        else:
            self.num_heads = train_config.folds

        # One head per fold.
        for _ in range(self.num_heads):
            self.heads.append(HeadVQA(train_config))

    def forward(self, img_features, question_features):
        output = []
        output_aux = []

        for head in self.heads:
            logits, logits_aux = head(img_features, question_features)

            # Convert logits to probabilities before averaging across heads.
            probs = logits.softmax(-1)
            probs_aux = logits_aux.softmax(-1)

            output.append(probs)
            output_aux.append(probs_aux)

        # Ensemble prediction: mean probability over all heads.
        output = torch.stack(output, dim=-1).mean(-1)
        output_aux = torch.stack(output_aux, dim=-1).mean(-1)

        return output, output_aux
      
def merge_vqa(train_config):
    # Initialize the ensemble model.
    model = NetVQA(train_config)

    for fold in train_config.folds:
        print("Loading weights from fold {} into head {}".format(fold, fold))

        checkpoint_path = "{}/{}/fold_{}".format(train_config.model_path, train_config.model, fold)

        if train_config.crossvalidation:
            # Load the best checkpoint of the cross-validation run.
            model_state_dict = torch.load('{}/weights_best.pth'.format(checkpoint_path))
        else:
            # Load the checkpoint saved at the end of training.
            model_state_dict = torch.load('{}/weights_end.pth'.format(checkpoint_path))

        model.heads[fold].load_state_dict(model_state_dict, strict=True)

    checkpoint_path = "{}/{}/weights_merged.pth".format(train_config.model_path, train_config.model)

    print("Saving weights of merged model:", checkpoint_path)

    torch.save(model.state_dict(), checkpoint_path)

    return model
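

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original pipeline).
    # The concrete attribute values below are assumptions; the model only reads the
    # attributes used above: model, folds, n_classes, aux_mapping. merge_vqa would
    # additionally need model_path, crossvalidation and saved checkpoints on disk.
    from types import SimpleNamespace

    train_config = SimpleNamespace(
        model='ViT-B/32',                            # CLIP backbone key (embedding dim 512)
        folds=[0, 1, 2],                             # three heads, one per fold
        n_classes=10,                                # hypothetical number of answer classes
        aux_mapping={i: i % 3 for i in range(10)},   # hypothetical answer -> aux-class mapping
    )

    model = NetVQA(train_config)
    img_features = torch.randn(4, 512)               # batch of CLIP image embeddings
    question_features = torch.randn(4, 512)          # batch of CLIP question embeddings
    probs, probs_aux = model(img_features, question_features)
    print(probs.shape, probs_aux.shape)              # torch.Size([4, 10]) torch.Size([4, 3])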