import torch


class HeadVQA(torch.nn.Module):
    """Single classification head on top of concatenated CLIP image and question embeddings."""

    def __init__(self, train_config):
        super().__init__()
        # Output embedding width of each supported CLIP backbone.
        embedding_size = {'RN50': 1024,
                          'RN101': 512,
                          'RN50x4': 640,
                          'RN50x16': 768,
                          'RN50x64': 1024,
                          'ViT-B/32': 512,
                          'ViT-B/16': 512,
                          'ViT-L/14': 768,
                          'ViT-L/14@336px': 768}
        # Number of distinct auxiliary classes defined by aux_mapping.
        n_aux_classes = len(set(train_config.aux_mapping.values()))

        self.ln1 = torch.nn.LayerNorm(embedding_size[train_config.model] * 2)
        self.dp1 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(embedding_size[train_config.model] * 2, 512)
        self.ln2 = torch.nn.LayerNorm(512)
        self.dp2 = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(512, train_config.n_classes)
        # Auxiliary classifier plus a sigmoid gate that modulates the VQA logits.
        self.fc_aux = torch.nn.Linear(512, n_aux_classes)
        self.fc_gate = torch.nn.Linear(n_aux_classes, train_config.n_classes)
        self.act_gate = torch.nn.Sigmoid()

    def forward(self, img_features, question_features):
        # Concatenate image and question embeddings along the feature dimension.
        xc = torch.cat((img_features, question_features), dim=-1)

        x = self.ln1(xc)
        x = self.dp1(x)
        x = self.fc1(x)

        # Auxiliary prediction drives a per-class gate on the main logits.
        aux = self.fc_aux(x)
        gate = self.fc_gate(aux)
        gate = self.act_gate(gate)

        x = self.ln2(x)
        x = self.dp2(x)
        vqa = self.fc2(x)

        # Gate the VQA logits with the auxiliary prediction.
        output = vqa * gate

        return output, aux
class NetVQA(torch.nn.Module):
    """Ensemble of HeadVQA heads, one per cross-validation fold, averaged at inference."""

    def __init__(self, train_config):
        super().__init__()
        self.heads = torch.nn.ModuleList()
        # train_config.folds is either a list of fold ids or the number of folds.
        if isinstance(train_config.folds, list):
            self.num_heads = len(train_config.folds)
        else:
            self.num_heads = train_config.folds
        for i in range(self.num_heads):
            self.heads.append(HeadVQA(train_config))

    def forward(self, img_features, question_features):
        output = []
        output_aux = []
        for head in self.heads:
            logits, logits_aux = head(img_features, question_features)
            # Convert per-head logits to probabilities before averaging.
            probs = logits.softmax(-1)
            probs_aux = logits_aux.softmax(-1)
            output.append(probs)
            output_aux.append(probs_aux)
        # Average probabilities across the ensemble heads.
        output = torch.stack(output, dim=-1).mean(-1)
        output_aux = torch.stack(output_aux, dim=-1).mean(-1)
        return output, output_aux
def merge_vqa(train_config):
    """Load the per-fold checkpoints into one NetVQA ensemble and save the merged weights."""
    # Initialize model
    model = NetVQA(train_config)

    for fold in train_config.folds:
        print("load weights from fold {} into head {}".format(fold, fold))
        checkpoint_path = "{}/{}/fold_{}".format(train_config.model_path, train_config.model, fold)
        if train_config.crossvalidation:
            # Load the best checkpoint selected during cross-validation.
            model_state_dict = torch.load('{}/weights_best.pth'.format(checkpoint_path))
        else:
            # Load the checkpoint saved at the end of training.
            model_state_dict = torch.load('{}/weights_end.pth'.format(checkpoint_path))
        model.heads[fold].load_state_dict(model_state_dict, strict=True)

    checkpoint_path = "{}/{}/weights_merged.pth".format(train_config.model_path, train_config.model)
    print("Saving weights of merged model:", checkpoint_path)
    torch.save(model.state_dict(), checkpoint_path)

    return model
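

# Sketch of a standalone smoke test for NetVQA (not part of the original pipeline):
# it builds a minimal, hypothetical train_config with types.SimpleNamespace and runs a
# forward pass on random features shaped like ViT-B/32 CLIP embeddings. The field
# names mirror the attributes used above; real configs and checkpoints may differ.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        model="ViT-B/32",                # selects the 512-dim embedding size above
        n_classes=1000,                  # number of VQA answer classes (assumed)
        aux_mapping={0: 0, 1: 1, 2: 2},  # answer id -> auxiliary class id (assumed)
        folds=[0, 1, 2],                 # one ensemble head per fold
    )

    net = NetVQA(demo_config).eval()
    img_features = torch.randn(4, 512)       # batch of 4 dummy CLIP image embeddings
    question_features = torch.randn(4, 512)  # batch of 4 dummy CLIP text embeddings
    with torch.no_grad():
        probs, probs_aux = net(img_features, question_features)
    print(probs.shape, probs_aux.shape)      # torch.Size([4, 1000]) torch.Size([4, 3])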