from .critic_objectives import *

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


############################
#   Simple Augmentations   #
############################

def permute(x):
    # Shuffle the sequence order.
    idx = torch.randperm(x.shape[0])
    return x[idx]


def noise(x):
    # Add small Gaussian noise.
    eps = torch.randn(x.shape) * 0.1
    return x + eps.to(x.device)


def drop(x):
    # Zero out 20% of the sequences.
    drop_num = x.shape[0] // 5
    x_aug = torch.clone(x)
    drop_idxs = np.random.choice(x.shape[0], drop_num, replace=False)
    x_aug[drop_idxs] = 0.0
    return x_aug


def mixup(x, alpha=1.0):
    # Convex combination of the batch with a shuffled copy of itself.
    indices = torch.randperm(x.shape[0])
    lam = np.random.beta(alpha, alpha)
    return x * lam + x[indices] * (1 - lam)


def identity(x):
    return x


def augment(x_batch):
    # Return two independently augmented views of the batch. Both views are
    # cloned so the input batch is not modified in place.
    v1 = torch.clone(x_batch)
    v2 = torch.clone(x_batch)
    transform_fns = [permute, noise, drop, identity]
    for i in range(x_batch.shape[0]):
        t_idxs = np.random.choice(4, 2, replace=False)
        v1[i] = transform_fns[t_idxs[0]](v1[i])
        v2[i] = transform_fns[t_idxs[1]](v2[i])
    return v1, v2


def augment_single(x_batch):
    # Return one augmented view of the batch.
    v2 = torch.clone(x_batch)
    transform_fns = [permute, noise, drop, identity]
    for i in range(x_batch.shape[0]):
        t = transform_fns[np.random.choice(4)]
        v2[i] = t(v2[i])
    return v2


def augment_embed_single(x_batch):
    # Return one augmented view of a batch of flat embeddings (2D input).
    v2 = torch.clone(x_batch)
    transform_fns = [noise, mixup, identity]
    t = transform_fns[np.random.choice(3)]
    return t(v2)


def augment_mimic(x_batch):
    # Embedding-level augmentation for 2D inputs, sequence-level otherwise.
    if x_batch.dim() == 2:
        return augment_embed_single(x_batch)
    return augment_single(x_batch)
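
# Illustrative smoke test for the augmentations above (a sketch, not part of
# the training pipeline; the tensor shapes are assumptions). Each augmented
# view keeps the input shape, and `augment` leaves the original batch intact.
def _demo_augmentations():
    x_batch = torch.randn(8, 20, 35)   # assumed (batch, seq_len, feat_dim)
    v1, v2 = augment(x_batch)
    assert v1.shape == x_batch.shape and v2.shape == x_batch.shape
    x_embed = torch.randn(8, 128)      # flat embeddings, as in the MIMIC path
    v = augment_mimic(x_embed)         # dispatches to augment_embed_single for 2D input
    assert v.shape == x_embed.shape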

##############
#   Models   #
##############

def mlp_head(dim_in, feat_dim):
    return nn.Sequential(
        nn.Linear(dim_in, dim_in),
        nn.ReLU(inplace=True),
        nn.Linear(dim_in, feat_dim)
    )


class SupConModel(nn.Module):
    """Backbone encoders + projection heads, trained with a SupCon critic."""

    def __init__(self, temperature, encoders, dim_ins, feat_dims,
                 use_label=False, head='mlp'):
        super(SupConModel, self).__init__()
        self.use_label = use_label
        self.encoders = nn.ModuleList(encoders)
        if head == 'linear':
            self.head1 = nn.Linear(dim_ins[0], feat_dims[0])
            self.head2 = nn.Linear(dim_ins[1], feat_dims[1])
        elif head == 'mlp':
            self.head1 = nn.Sequential(
                nn.Linear(dim_ins[0], dim_ins[0]),
                nn.ReLU(inplace=True),
                nn.Linear(dim_ins[0], feat_dims[0])
            )
            self.head2 = nn.Sequential(
                nn.Linear(dim_ins[1], dim_ins[1]),
                nn.ReLU(inplace=True),
                nn.Linear(dim_ins[1], feat_dims[1])
            )
        else:
            raise NotImplementedError('head not supported: {}'.format(head))
        self.critic = SupConLoss(temperature=temperature)

    def forward(self, x1, x2, y):
        feat1 = self.head1(self.encoders[0](x1))
        feat2 = self.head2(self.encoders[1](x2))
        feat = torch.cat([feat1.unsqueeze(1), feat2.unsqueeze(1)], dim=1)
        loss = self.critic(feat, y) if self.use_label else self.critic(feat)
        return loss

    def get_embedding(self, x1, x2):
        return self.encoders[0](x1), self.encoders[1](x2)


class FactorCLSUP(nn.Module):
    # Supervised FactorCL model. In this pruned variant only the conditional
    # CLUB critic is active; the remaining heads and critics of the full
    # objective are kept below as disabled (####) code.

    def __init__(self, encoders, feat_dims, y_ohe_dim, temperature=1,
                 activation='relu', lr=1e-4, ratio=1):
        super(FactorCLSUP, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'
        self.lr = lr
        self.ratio = ratio
        self.y_ohe_dim = y_ohe_dim
        self.temperature = temperature
        self.feat_dims = feat_dims

        # encoder backbones
        ####self.backbones = nn.ModuleList(encoders)

        # linear projection heads
        ####self.linears_infonce_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        self.linears_club_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        ####self.linears_infonce_x1y = mlp_head(self.feat_dims[0], self.feat_dims[0])
        ####self.linears_infonce_x2y = mlp_head(self.feat_dims[1], self.feat_dims[1])
        ####self.linears_infonce_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        ####self.linears_club_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])

        # critics
        ####self.infonce_x1x2 = InfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim,
        ####                                  self.critic_layers, activation, temperature=temperature)
        self.club_x1x2_cond = CLUBInfoNCECritic(self.feat_dims[0] + self.y_ohe_dim, self.feat_dims[1] + self.y_ohe_dim,
                                                self.critic_hidden_dim, self.critic_layers, activation,
                                                temperature=temperature)
        ####self.infonce_x1y = InfoNCECritic(self.feat_dims[0], 1, self.critic_hidden_dim,
        ####                                 self.critic_layers, activation, temperature=temperature)
        ####self.infonce_x2y = InfoNCECritic(self.feat_dims[1], 1, self.critic_hidden_dim,
        ####                                 self.critic_layers, activation, temperature=temperature)
        ####self.infonce_x1x2_cond = InfoNCECritic(self.feat_dims[0] + self.y_ohe_dim, self.feat_dims[1] + self.y_ohe_dim,
        ####                                       self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
        ####self.club_x1x2 = CLUBInfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim,
        ####                                   self.critic_layers, activation, temperature=temperature)

    def ohe(self, y):
        # One-hot encode integer labels given as an (N, 1) tensor.
        N = y.shape[0]
        y_ohe = torch.zeros((N, self.y_ohe_dim))
        y_ohe[torch.arange(N).long(), y.T[0].long()] = 1
        return y_ohe

    def forward(self, x1, x2, y):
        # Inputs are precomputed embeddings (the backbones are disabled).
        ####x1_embed = self.backbones[0](x1)
        ####x2_embed = self.backbones[1](x2)
        x1_embed, x2_embed = x1, x2
        x1_embed = F.normalize(x1_embed, dim=-1)
        x2_embed = F.normalize(x2_embed, dim=-1)

        # One-hot label for the conditional critic.
        y_ohe = self.ohe(y).cuda()

        # Compute losses; only the conditional CLUB term is active.
        ####uncond_losses = [self.infonce_x1x2(self.linears_infonce_x1x2[0](x1_embed), self.linears_infonce_x1x2[1](x2_embed)),
        ####                 self.club_x1x2(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
        ####                 self.infonce_x1y(self.linears_infonce_x1y(x1_embed), y),
        ####                 self.infonce_x2y(self.linears_infonce_x2y(x2_embed), y)]
        ####
        ####cond_losses = [self.infonce_x1x2_cond(torch.cat([self.linears_infonce_x1x2_cond[0](x1_embed), y_ohe], dim=1),
        ####                                      torch.cat([self.linears_infonce_x1x2_cond[1](x2_embed), y_ohe], dim=1)),
        ####               self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
        ####                                   torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1))]
        cond_losses = [self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
                                           torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1))]

        ####return sum(uncond_losses) + sum(cond_losses)
        return sum(cond_losses)

    def learning_loss(self, x1, x2, y):
        # Inputs are precomputed embeddings, as in forward().
        x1_embed, x2_embed = x1, x2
        x1_embed = F.normalize(x1_embed, dim=-1)
        x2_embed = F.normalize(x2_embed, dim=-1)
        y_ohe = self.ohe(y).cuda()

        # InfoNCE learning loss for the active CLUB-NCE critic. The
        # unconditional term is disabled along with self.club_x1x2 above.
        ####learning_losses = [self.club_x1x2.learning_loss(self.linears_club_x1x2[0](x1_embed),
        ####                                                self.linears_club_x1x2[1](x2_embed)),
        learning_losses = [self.club_x1x2_cond.learning_loss(
            torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
            torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1))]
        return sum(learning_losses)

    def get_embedding(self, x1, x2):
        # Inputs are precomputed embeddings; only the active head contributes.
        # The heads of the full objective are kept below as disabled code.
        ####x1_embed = self.backbones[0](x1)
        ####x2_embed = self.backbones[1](x2)
        x1_embed, x2_embed = x1, x2
        ####x1_reps = [self.linears_infonce_x1x2[0](x1_embed), self.linears_club_x1x2[0](x1_embed),
        ####           self.linears_infonce_x1y(x1_embed), self.linears_infonce_x1x2_cond[0](x1_embed),
        ####           self.linears_club_x1x2_cond[0](x1_embed)]
        ####x2_reps = [self.linears_infonce_x1x2[1](x2_embed), self.linears_club_x1x2[1](x2_embed),
        ####           self.linears_infonce_x2y(x2_embed), self.linears_infonce_x1x2_cond[1](x2_embed),
        ####           self.linears_club_x1x2_cond[1](x2_embed)]
        x1_reps = [self.linears_club_x1x2_cond[0](x1_embed)]
        x2_reps = [self.linears_club_x1x2_cond[1](x2_embed)]
        return torch.cat(x1_reps, dim=1), torch.cat(x2_reps, dim=1)

    def get_optims(self):
        # Optimizers over the active modules only; the disabled backbones,
        # heads, and critics have no parameters to train here.
        non_CLUB_params = [self.linears_club_x1x2_cond.parameters()]
        CLUB_params = [self.club_x1x2_cond.parameters()]
        non_CLUB_optims = [optim.Adam(param, lr=self.lr) for param in non_CLUB_params]
        CLUB_optims = [optim.Adam(param, lr=self.lr) for param in CLUB_params]
        return non_CLUB_optims, CLUB_optims
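
# Minimal usage sketch for FactorCLSUP (an illustration, not part of the
# module). Assumptions: a CUDA device is available, and the inputs are
# precomputed 128-d embeddings, since the backbone lines above are disabled;
# all sizes are illustrative.
def _demo_factorcl_sup():
    model = FactorCLSUP(encoders=[None, None], feat_dims=[128, 128], y_ohe_dim=2).cuda()
    x1 = torch.randn(16, 128).cuda()                 # modality-1 embeddings
    x2 = torch.randn(16, 128).cuda()                 # modality-2 embeddings
    y = torch.randint(0, 2, (16, 1)).float().cuda()  # binary labels, shape (N, 1)
    loss = model(x1, x2, y)                          # main objective (conditional CLUB term)
    club_fit_loss = model.learning_loss(x1, x2, y)   # separate critic-fitting loss
    return loss, club_fit_loss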

class FactorCLSSL(nn.Module):
    # Self-supervised FactorCL model: augmented views stand in for the label,
    # so all critics of the full objective are active.

    def __init__(self, encoders, feat_dims, y_ohe_dim, temperature=1,
                 activation='relu', lr=1e-4, ratio=1):
        super(FactorCLSSL, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'
        self.lr = lr
        self.ratio = ratio
        self.y_ohe_dim = y_ohe_dim
        self.temperature = temperature
        self.feat_dims = feat_dims

        # encoder backbones
        ####self.backbones = nn.ModuleList(encoders)

        # linear projection heads
        self.linears_infonce_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        self.linears_club_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        self.linears_infonce_x1y = mlp_head(self.feat_dims[0], self.feat_dims[0])
        self.linears_infonce_x2y = mlp_head(self.feat_dims[1], self.feat_dims[1])
        self.linears_infonce_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
        self.linears_club_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])

        # critics (the conditional critics take [embedding; augmented embedding]
        # concatenations, hence the doubled input dimensions)
        self.infonce_x1x2 = InfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim,
                                          self.critic_layers, activation, temperature=temperature)
        self.club_x1x2_cond = CLUBInfoNCECritic(self.feat_dims[0] * 2, self.feat_dims[1] * 2, self.critic_hidden_dim,
                                                self.critic_layers, activation, temperature=temperature)
        self.infonce_x1y = InfoNCECritic(self.feat_dims[0], self.feat_dims[0], self.critic_hidden_dim,
                                         self.critic_layers, activation, temperature=temperature)
        self.infonce_x2y = InfoNCECritic(self.feat_dims[1], self.feat_dims[1], self.critic_hidden_dim,
                                         self.critic_layers, activation, temperature=temperature)
        self.infonce_x1x2_cond = InfoNCECritic(self.feat_dims[0] * 2, self.feat_dims[1] * 2, self.critic_hidden_dim,
                                               self.critic_layers, activation, temperature=temperature)
        self.club_x1x2 = CLUBInfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim,
                                           self.critic_layers, activation, temperature=temperature)

    def ohe(self, y):
        # One-hot encode integer labels given as an (N, 1) tensor (unused in
        # the SSL objective, kept for interface parity with FactorCLSUP).
        N = y.shape[0]
        y_ohe = torch.zeros((N, self.y_ohe_dim))
        y_ohe[torch.arange(N).long(), y.T[0].long()] = 1
        return y_ohe

    def forward(self, x1, x2, x1_aug, x2_aug):
        # Inputs are precomputed embeddings (the backbones are disabled).
        ####x1_embed = self.backbones[0](x1)
        ####x2_embed = self.backbones[1](x2)
        ####x1_aug_embed = self.backbones[0](x1_aug)
        ####x2_aug_embed = self.backbones[1](x2_aug)
        x1_embed, x2_embed, x1_aug_embed, x2_aug_embed = x1, x2, x1_aug, x2_aug
        x1_embed = F.normalize(x1_embed, dim=-1)
        x2_embed = F.normalize(x2_embed, dim=-1)
        x1_aug_embed = F.normalize(x1_aug_embed, dim=-1)
        x2_aug_embed = F.normalize(x2_aug_embed, dim=-1)

        # Compute losses.
        uncond_losses = [self.infonce_x1x2(self.linears_infonce_x1x2[0](x1_embed), self.linears_infonce_x1x2[1](x2_embed)),
                         self.club_x1x2(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
                         self.infonce_x1y(self.linears_infonce_x1y(x1_embed), self.linears_infonce_x1y(x1_aug_embed)),
                         self.infonce_x2y(self.linears_infonce_x2y(x2_embed), self.linears_infonce_x2y(x2_aug_embed))]

        cond_losses = [self.infonce_x1x2_cond(torch.cat([self.linears_infonce_x1x2_cond[0](x1_embed),
                                                         self.linears_infonce_x1x2_cond[0](x1_aug_embed)], dim=1),
                                              torch.cat([self.linears_infonce_x1x2_cond[1](x2_embed),
                                                         self.linears_infonce_x1x2_cond[1](x2_aug_embed)], dim=1)),
                       self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
                                                      self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
                                           torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
                                                      self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))]

        return sum(uncond_losses) + sum(cond_losses)

    def learning_loss(self, x1, x2, x1_aug, x2_aug):
        # Inputs are precomputed embeddings, as in forward().
        x1_embed, x2_embed, x1_aug_embed, x2_aug_embed = x1, x2, x1_aug, x2_aug
        x1_embed = F.normalize(x1_embed, dim=-1)
        x2_embed = F.normalize(x2_embed, dim=-1)
        x1_aug_embed = F.normalize(x1_aug_embed, dim=-1)
        x2_aug_embed = F.normalize(x2_aug_embed, dim=-1)

        # InfoNCE learning losses for the CLUB-NCE critics.
        learning_losses = [self.club_x1x2.learning_loss(self.linears_club_x1x2[0](x1_embed),
                                                        self.linears_club_x1x2[1](x2_embed)),
                           self.club_x1x2_cond.learning_loss(
                               torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
                                          self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
                               torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
                                          self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))]
        return sum(learning_losses)

    def get_embedding(self, x1, x2):
        # Inputs are precomputed embeddings (the backbones are disabled);
        # concatenate the outputs of all projection heads.
        ####x1_embed = self.backbones[0](x1)
        ####x2_embed = self.backbones[1](x2)
        x1_embed, x2_embed = x1, x2
        x1_reps = [self.linears_infonce_x1x2[0](x1_embed), self.linears_club_x1x2[0](x1_embed),
                   self.linears_infonce_x1y(x1_embed), self.linears_infonce_x1x2_cond[0](x1_embed),
                   self.linears_club_x1x2_cond[0](x1_embed)]
        x2_reps = [self.linears_infonce_x1x2[1](x2_embed), self.linears_club_x1x2[1](x2_embed),
                   self.linears_infonce_x2y(x2_embed), self.linears_infonce_x1x2_cond[1](x2_embed),
                   self.linears_club_x1x2_cond[1](x2_embed)]
        return torch.cat(x1_reps, dim=1), torch.cat(x2_reps, dim=1)

    def get_optims(self):
        # The disabled backbones are excluded from the parameter groups.
        ####non_CLUB_params = [self.backbones.parameters(), ...]
        non_CLUB_params = [self.infonce_x1x2.parameters(),
                           self.infonce_x1y.parameters(),
                           self.infonce_x2y.parameters(),
                           self.infonce_x1x2_cond.parameters(),
                           self.linears_infonce_x1x2.parameters(),
                           self.linears_infonce_x1y.parameters(),
                           self.linears_infonce_x2y.parameters(),
                           self.linears_infonce_x1x2_cond.parameters(),
                           self.linears_club_x1x2_cond.parameters(),
                           self.linears_club_x1x2.parameters()]
        CLUB_params = [self.club_x1x2_cond.parameters(), self.club_x1x2.parameters()]
        non_CLUB_optims = [optim.Adam(param, lr=self.lr) for param in non_CLUB_params]
        CLUB_optims = [optim.Adam(param, lr=self.lr) for param in CLUB_params]
        return non_CLUB_optims, CLUB_optims
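
# Companion sketch for FactorCLSSL (same assumed sizes and CUDA assumption as
# the FactorCLSUP demo above); augmented views replace the label.
def _demo_factorcl_ssl():
    model = FactorCLSSL(encoders=[None, None], feat_dims=[128, 128], y_ohe_dim=2).cuda()
    x1 = torch.randn(16, 128).cuda()
    x2 = torch.randn(16, 128).cuda()
    x1_aug = augment_mimic(x1)          # 2D input -> embedding-level augmentation
    x2_aug = augment_mimic(x2)
    loss = model(x1, x2, x1_aug, x2_aug)
    club_fit_loss = model.learning_loss(x1, x2, x1_aug, x2_aug)
    return loss, club_fit_loss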

########################
#   Training Scripts   #
########################

# MOSI/MOSEI Training

def mosi_label(y_batch):
    # Binarize MOSI/MOSEI sentiment: non-negative -> 1, negative -> 0.
    res = copy.deepcopy(y_batch)
    res[y_batch >= 0] = 1
    res[y_batch < 0] = 0
    return res


def train_supcon_mosi(model, train_loader, optimizer, modalities=[0, 2], num_epoch=100):
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            y_batch = mosi_label(data_batch[3]).float().cuda()

            loss = model(x1_batch, x2_batch, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


def train_sup_mosi(model, train_loader, modalities=[0, 2], num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            y_batch = mosi_label(data_batch[3]).float().cuda()

            # Update the projection heads and non-CLUB critics.
            loss = model(x1_batch, x2_batch, y_batch)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            # Separately fit the CLUB critics.
            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


def train_ssl_mosi(model, train_loader, modalities=[0, 2], num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            x1_aug = augment_single(x1_batch)
            x2_aug = augment_single(x2_batch)

            loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return

# Sarcasm/Humor Training

def sarcasm_label(y_batch):
    # Map the -1 labels to 0 for binary classification.
    res = copy.deepcopy(y_batch)
    res[y_batch == -1] = 0
    return res


def train_supcon_sarcasm(model, train_loader, optimizer, modalities=[0, 2], num_epoch=100):
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            y_batch = sarcasm_label(data_batch[3]).float().cuda()

            loss = model(x1_batch, x2_batch, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


def train_sup_sarcasm(model, train_loader, modalities=[0, 2], num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            y_batch = sarcasm_label(data_batch[3]).float().cuda()

            #loss, losses, ts = model(x_batch, y_batch)
            loss = model(x1_batch, x2_batch, y_batch)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


def train_ssl_sarcasm(model, train_loader, modalities=[0, 2], num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0][modalities[0]].float().cuda()
            x2_batch = data_batch[0][modalities[1]].float().cuda()
            x1_aug = augment_single(x1_batch)
            x2_aug = augment_single(x2_batch)

            loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


# MIMIC Training

def train_supcon_mimic(model, train_loader, optimizer, num_epoch=100):
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0].float().cuda()
            x2_batch = data_batch[1].float().cuda()
            y_batch = data_batch[2].unsqueeze(0).T.float().cuda()

            loss = model(x1_batch, x2_batch, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return


def train_sup_mimic(model, train_loader, num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0].float().cuda()
            x2_batch = data_batch[1].float().cuda()
            y_batch = data_batch[2].unsqueeze(0).T.float().cuda()

            loss = model(x1_batch, x2_batch, y_batch)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return

def train_ssl_mimic(model, train_loader, num_epoch=50, num_club_iter=1):
    non_CLUB_optims, CLUB_optims = model.get_optims()
    losses = []
    for _iter in range(num_epoch):
        for i_batch, data_batch in enumerate(train_loader):
            x1_batch = data_batch[0].float().cuda()
            x2_batch = data_batch[1].float().cuda()
            x1_aug = augment_mimic(x1_batch)
            x2_aug = augment_mimic(x2_batch)

            loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
            losses.append(loss.detach().cpu().numpy())
            for optimizer in non_CLUB_optims:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in non_CLUB_optims:
                optimizer.step()

            for _ in range(num_club_iter):
                learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
                for optimizer in CLUB_optims:
                    optimizer.zero_grad()
                learning_loss.backward()
                for optimizer in CLUB_optims:
                    optimizer.step()

            if i_batch % 100 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
    return
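
# End-to-end smoke test sketch for the MIMIC-style SSL loop. Everything here
# is an assumption for illustration: synthetic tensors stand in for a real
# loader yielding (x1, x2, y) batches, and a CUDA device is assumed.
def _smoke_test_train_ssl_mimic():
    from torch.utils.data import DataLoader, TensorDataset
    x1 = torch.randn(64, 128)         # modality-1 embeddings
    x2 = torch.randn(64, 128)         # modality-2 embeddings
    y = torch.randint(0, 2, (64,))    # labels, unused by the SSL loop
    loader = DataLoader(TensorDataset(x1, x2, y), batch_size=16)
    model = FactorCLSSL(encoders=[None, None], feat_dims=[128, 128], y_ohe_dim=2).cuda()
    train_ssl_mimic(model, loader, num_epoch=1)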