# Source path: LMM/mogen/models/architectures/base_architecture.py
# Repository page residue: "mingyuan's picture", "initial commit", 373af33
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from mmcv.runner import BaseModule
def to_cpu(x: torch.Tensor) -> torch.Tensor:
    """Return a graph-detached CPU version of ``x``.

    Anything that is not a ``torch.Tensor`` is handed back unchanged, so the
    function can be applied to values of mixed types.

    Args:
        x (torch.Tensor): The value to convert.

    Returns:
        torch.Tensor: ``x`` detached and moved to CPU when it is a tensor,
        otherwise ``x`` itself.
    """
    # Guard clause: non-tensor values pass straight through.
    if not isinstance(x, torch.Tensor):
        return x
    detached = x.detach()
    return detached.cpu()
class BaseArchitecture(BaseModule):
    """Base class for mogen architecture.

    ``forward`` dispatches to ``forward_train`` or ``forward_test`` depending
    on ``self.training``. Both are no-ops here (they return ``None``) and are
    meant to be overridden by concrete architectures.

    Args:
        init_cfg (dict, optional): Initialization config for the module.
    """

    def __init__(self, init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)

    def forward_train(self, **kwargs):
        """Forward computation during training."""
        pass

    def forward_test(self, **kwargs):
        """Forward computation during testing."""
        pass

    def _parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information. Each value must be a
                tensor or a list of tensors.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars)
                - loss is the sum of every entry whose key contains ``'loss'``,
                - log_vars maps each entry (plus the total under ``'loss'``)
                  to a Python scalar for logging.

        Raises:
            TypeError: If a value is neither a tensor nor a list of tensors.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose name contains 'loss' contribute to the total.
        loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
        log_vars['loss'] = loss

        # The distributed state does not change inside the loop; query it once.
        distributed = dist.is_available() and dist.is_initialized()
        for loss_name, loss_value in log_vars.items():
            if distributed:
                # Average across ranks on a copy so the returned ``loss``
                # tensor (used for backpropagation) is left untouched.
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()
        return loss, log_vars

    def _forward_step(self, data: Dict) -> Dict:
        """Run the model on one batch and package the parsed losses.

        Shared implementation of ``train_step`` and ``val_step``.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        return dict(loss=loss, log_vars=log_vars, num_samples=len(data['motion']))

    def train_step(self, data: Dict, optimizer: torch.optim.Optimizer) -> Dict:
        """The iteration step during training.

        This method defines an iteration step during training, excluding
        backpropagation and optimizer updating, which are handled by an
        optimizer hook.

        Args:
            data (dict): The output of the dataloader.
            optimizer (torch.optim.Optimizer): The optimizer object (unused).

        Returns:
            dict: A dictionary containing the loss, log_vars for logging, and
            the number of samples.
                - ``loss``: A tensor for backpropagation, which may be a
                  weighted sum of multiple losses.
                - ``log_vars``: All the variables to be logged.
                - ``num_samples``: The number of samples in the batch
                  (``len(data['motion'])``).
        """
        return self._forward_step(data)

    def val_step(
        self, data: Dict, optimizer: Optional[torch.optim.Optimizer] = None
    ) -> Dict:
        """The iteration step during validation.

        Args:
            data (dict): The output of the dataloader.
            optimizer (torch.optim.Optimizer, optional): The optimizer object
                (unused).

        Returns:
            dict: A dictionary containing the loss, log_vars for logging, and
            the number of samples.
        """
        return self._forward_step(data)

    def forward(self, **kwargs):
        """Forward computation based on the training or testing mode."""
        if self.training:
            return self.forward_train(**kwargs)
        return self.forward_test(**kwargs)

    @staticmethod
    def _to_owned_cpu(x):
        """Like ``to_cpu`` but never aliases the storage of ``x``.

        ``Tensor.cpu()`` returns the tensor itself when it already lives on
        CPU, so the result of ``to_cpu`` would share memory with ``x``. A
        clone is taken in that case so later in-place edits stay local.
        Non-tensor values are returned unchanged.
        """
        out = to_cpu(x)
        if isinstance(x, torch.Tensor) and x.device.type == 'cpu':
            out = out.clone()
        return out

    def split_results(self, results: Dict[str, torch.Tensor]) -> List[Dict]:
        """Split batched results into individual outputs.

        For every sample, frames at index ``motion_length`` and beyond are
        zeroed in ``motion``, and frames at ``pred_motion_length`` and beyond
        are zeroed in ``pred_motion``. The zeroing is applied to private
        copies, so the tensors inside ``results`` are not modified.

        Args:
            results (dict): The batched results from the model. Must contain
                'motion', 'pred_motion', 'motion_length' and 'motion_mask';
                'pred_motion_length', 'pred_motion_mask' and 'motion_metas'
                are optional.

        Returns:
            list: A list of dictionaries where each dictionary contains
            results for a single instance.
        """
        batch_size = results['motion'].shape[0]
        output = []
        for i in range(batch_size):
            batch_output = dict()
            # Own copies: both motions are edited in place below.
            batch_output['motion'] = self._to_owned_cpu(results['motion'][i])
            batch_output['pred_motion'] = self._to_owned_cpu(results['pred_motion'][i])
            batch_output['motion_length'] = to_cpu(results['motion_length'][i])
            batch_output['motion'][batch_output['motion_length']:, :] = 0
            batch_output['motion_mask'] = to_cpu(results['motion_mask'][i])

            # Predicted length / mask fall back to the ground-truth ones.
            if 'pred_motion_length' in results:
                batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i])
            else:
                batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i])
            batch_output['pred_motion'][batch_output['pred_motion_length']:, :] = 0
            if 'pred_motion_mask' in results:
                batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i])
            else:
                batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i])

            if 'motion_metas' in results:
                motion_metas = results['motion_metas'][i]
                # Surface selected meta fields at the top level when present.
                if 'text' in motion_metas:
                    batch_output['text'] = motion_metas['text']
                if 'token' in motion_metas:
                    batch_output['token'] = motion_metas['token']
                if 'meta_data' in motion_metas and 'category_id' in motion_metas['meta_data']:
                    batch_output['category_id'] = motion_metas['meta_data']['category_id']
                batch_output['motion_metas'] = motion_metas
            output.append(batch_output)
        return output